@@ -1,15 +1,20 @@
+from tqdm import tqdm
 from typing import Optional, Dict, Any, Tuple, List
+import gc

 import torch
-from transformers.cache_utils import DynamicCache
+from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
+


 class OmniGenCache(DynamicCache):
-    def __init__(self,
-                 num_tokens_for_img: int, offload_kv_cache: bool = False) -> None:
+    def __init__(self,
+                 num_tokens_for_img: int,
+                 offload_kv_cache: bool = False) -> None:
         if not torch.cuda.is_available():
-            raise RuntimeError(
-                "OmniGenCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!")
15+ # print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
16+ # offload_kv_cache = False
17+ raise RuntimeError ("OffloadedCache can only be used with a GPU. If there is no GPU, you need to set use_kv_cache=False, which will result in longer inference time!" )
         super().__init__()
         self.original_device = []
         self.prefetch_stream = torch.cuda.Stream()
@@ -25,17 +30,19 @@ def prefetch_layer(self, layer_idx: int):
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                 self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

+
     def evict_previous_layer(self, layer_idx: int):
         "Moves the previous layer cache to the CPU"
         if len(self) > 2:
             # We do it on the default stream so it occurs after all earlier computations on these tensors are done
-            if layer_idx == 0:
+            if layer_idx == 0:
                 prev_layer_idx = -1
             else:
                 prev_layer_idx = (layer_idx - 1) % len(self)
             self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
             self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

+
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
         if layer_idx < len(self):
@@ -44,12 +51,12 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
                 torch.cuda.current_stream().synchronize()
                 self.evict_previous_layer(layer_idx)
                 # Load current layer cache to its original device if not already there
-                # original_device = self.original_device[layer_idx]
+                original_device = self.original_device[layer_idx]
                 # self.prefetch_stream.synchronize(original_device)
-                self.prefetch_stream.synchronize()
+                torch.cuda.synchronize(self.prefetch_stream)
                 key_tensor = self.key_cache[layer_idx]
                 value_tensor = self.value_cache[layer_idx]
-
+
                 # Prefetch the next layer
                 self.prefetch_layer((layer_idx + 1) % len(self))
             else:
@@ -58,13 +65,13 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
             return (key_tensor, value_tensor)
         else:
             raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
+
     def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
+            self,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            layer_idx: int,
+            cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -85,13 +92,13 @@ def update(
             raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
         elif len(self.key_cache) == layer_idx:
             # only cache the states for condition tokens
-            key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
-            value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]
+            key_states = key_states[..., :-(self.num_tokens_for_img + 1), :]
+            value_states = value_states[..., :-(self.num_tokens_for_img + 1), :]

-        # Update the number of seen tokens
+        # Update the number of seen tokens
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
-
+
         self.key_cache.append(key_states)
         self.value_cache.append(value_states)
         self.original_device.append(key_states.device)
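
For context on the pattern this commit lands: `OmniGenCache` keeps only a small sliding window of layers resident on the GPU. Each `__getitem__` call synchronizes the prefetch stream, evicts the previous layer's keys/values to CPU, and starts a non-blocking prefetch of the next layer, while `update` trims the trailing `num_tokens_for_img + 1` states so only condition tokens are cached. A minimal usage sketch follows, assuming a Hugging Face-style forward pass; `model`, `input_ids`, `num_steps`, and the token count are hypothetical placeholders, not part of this commit:

```python
import torch

# Hypothetical wiring, not from this commit: any HF-style forward that accepts
# past_key_values will route per-layer cache reads through
# OmniGenCache.__getitem__, which prefetches layer i+1 and evicts layer i-1.
num_tokens_for_img = 1024  # assumed: number of image tokens generated per step
cache = OmniGenCache(num_tokens_for_img, offload_kv_cache=True)

with torch.no_grad():
    for _ in range(num_steps):  # assumed: denoising/decoding steps
        output = model(input_ids, past_key_values=cache, use_cache=True)
        # Reuse the cache of condition tokens across steps; image-token
        # states are trimmed inside update(), so they are recomputed each step.
        cache = output.past_key_values
```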