@@ -19,27 +19,27 @@ def estimate_autoregressive_vram(
    max_seq_len: int,
    batch_size: int = 1,
    dtype=torch.float16,
-    intermediate_factor: float = 4.0,
+    intermediate_factor: float = 4.0,
    device=torch.device('cuda')
) -> bool:
-
+
    dtype_size = torch.finfo(dtype).bits // 8
    kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size

-    # we only calculate hidden states in cuda graphs, so we don't care about the output logits
-    input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
-
+    # we only calculate hidden states in cuda graphs, so we don't care about the output logits
+    input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
+
    # rough calculation for activation sizes
    intermediate_bytes = intermediate_factor * output_bytes
-
+
    total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
-
+
    # get vram info
    free_vram = get_free_memory(device)
    minimum_vram = minimum_inference_memory()
-
+
    enough_vram = free_vram - minimum_vram >= total_estimated
-
+
    return enough_vram

class TopKLogits:
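As a rough sanity check of the KV-cache term in `estimate_autoregressive_vram` above, the snippet below plugs in made-up model dimensions (assumed values, not from this PR): 32 layers with a 4096-wide hidden state and an 8192-token fp16 cache come to about 4 GiB before activations.

```python
import torch

# hypothetical model dimensions, chosen only to illustrate the formula above
num_layers, hidden_dim, max_seq_len, batch_size = 32, 4096, 8192, 1
dtype_size = torch.finfo(torch.float16).bits // 8  # 2 bytes per element

# keys + values (the factor of 2) for every layer and position of the cache
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
print(kv_cache_bytes / 1024**3)  # 4.0 (GiB)
```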
@@ -64,7 +64,7 @@ def __init__(self, temperature: float):
    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores / self.temperature
        return scores_processed
-
+
class TopPLogitsWarper:
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
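For context, a top-p (nucleus) filter with this signature is commonly implemented along the lines of the sketch below; this is a generic formulation to show the idea, not the body of this PR's `TopPLogitsWarper`.

```python
import torch

def top_p_filter(scores: torch.Tensor, top_p: float = 0.9,
                 filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1) -> torch.Tensor:
    # sort logits descending and find the smallest prefix whose probability mass exceeds top_p
    sorted_logits, sorted_indices = torch.sort(scores, descending=True, dim=-1)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # shift right so the token that crosses the threshold is kept, and honor min_tokens_to_keep
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., :min_tokens_to_keep] = False
    # map the mask back to the unsorted vocabulary order and mask out the tail
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)
```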
@@ -175,7 +175,7 @@ def from_model_config(cls, config_dict: dict, **kwargs) -> GenerationConfig:

        config_dict = {key: value for key, value in config_dict.items() if value is not None}
        valid_fields = {f.name for f in fields(cls)}
-
+
        filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}

        generation_config = cls(**filtered_args)
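The `from_model_config` pattern above (drop `None` values, then keep only keys that are actual dataclass fields) can be exercised in isolation; `ToyConfig` below is a made-up stand-in, not this PR's `GenerationConfig`.

```python
from dataclasses import dataclass, fields

@dataclass
class ToyConfig:
    max_new_length: int = 1024
    top_p: float = 1.0

raw = {"max_new_length": 256, "top_p": None, "unknown_key": 42}
raw = {k: v for k, v in raw.items() if v is not None}   # drop unset values
valid = {f.name for f in fields(ToyConfig)}              # names the dataclass accepts
print(ToyConfig(**{k: v for k, v in raw.items() if k in valid}))
# ToyConfig(max_new_length=256, top_p=1.0)
```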
@@ -216,7 +216,7 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
        self.model.cache_config = self.cache_config

        self.kv_caches = {
-
+
            length: StaticCache(
                config=self.cache_config,
                max_batch_size=self.cache_config.max_batch,
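The dict above pre-allocates one `StaticCache` per bucket length, so picking a cache at generation time reduces to "smallest bucket that fits the full sequence". A minimal, self-contained illustration of that selection (hypothetical helper, not part of the PR):

```python
def pick_kv_bucket(needed_len: int, bucket_lengths=(1024, 4096, 8192)) -> int:
    # choose the smallest pre-allocated cache that can hold prompt + new tokens
    for length in sorted(bucket_lengths):
        if needed_len <= length:
            return length
    raise ValueError(f"sequence of {needed_len} tokens exceeds the largest bucket")

print(pick_kv_bucket(600))   # 1024
print(pick_kv_bucket(3000))  # 4096
```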
@@ -234,8 +234,8 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):

        # cuda graphs only help if input shapes are constant
        if (
-            device == "cuda"
-            and hasattr(model, "capture_model")
+            device == "cuda"
+            and hasattr(model, "capture_model")
            and self.model.cache_implementation == "static"
            and self.model.use_kv_buckets
            and enough_vram
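The "constant input shapes" requirement comes from how CUDA graph replays reuse fixed buffers. A generic capture/replay sketch in plain PyTorch (a standalone illustration of the idea, not this PR's `capture_model` implementation):

```python
import torch

if torch.cuda.is_available():
    layer = torch.nn.Linear(16, 16, device="cuda")
    static_input = torch.zeros(1, 16, device="cuda")   # buffer reused for every replay

    # warm up on a side stream before capture, as CUDA graphs require
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            static_output = layer(static_input)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = layer(static_input)

    # replays reuse the captured kernels, so shapes (and buffers) must never change
    static_input.copy_(torch.randn(1, 16, device="cuda"))
    g.replay()
```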
@@ -247,7 +247,7 @@ def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
    @torch.inference_mode()
    def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length=0,
                 top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed=None, **kwargs):
-
+
        if seed is not None:
            torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed)
        else:
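A per-call `torch.Generator` is what makes `seed` reproducible without disturbing global RNG state; a tiny CPU example of the same idea (toy probabilities, not this PR's sampling loop):

```python
import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])

gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(1234)

# identical seeds give identical draws, independent of torch.manual_seed / global state
tok_a = torch.multinomial(probs, num_samples=1, generator=gen_a)
tok_b = torch.multinomial(probs, num_samples=1, generator=gen_b)
assert torch.equal(tok_a, tok_b)
```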
@@ -335,7 +335,7 @@ def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length:
        # TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
        if not hasattr(self.model, "_sample"):
            raise ValueError("Model doesn't support AutoRegressive Generation!")
-
+
        self._prepare_kv_caches()

        result = self.model._sample(
@@ -347,7 +347,7 @@ def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length:
        )

        return result
-
+
    def _prepare_kv_caches(self):
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()
@@ -357,13 +357,13 @@ def get_generation_mode(self, config: GenerationConfig):
            return GenerationSampling.BEAM_SAMPLING
        else:
            return GenerationSampling.GREEDY_SEARCH
-
+
    def _prepare_generated_length(
        self,
        generation_config: GenerationConfig,
        input_ids_length,
    ):
-
+
        """ max_length = user_input_id_tokens + generation_max_length """

        if generation_config.max_new_length is not None:
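The docstring's rule is plain additive bookkeeping; with an assumed 32-token prompt:

```python
# hypothetical numbers mirroring _prepare_generated_length's bookkeeping
input_ids_length = 32
max_new_length, min_new_length = 1024, 0
max_length = max_new_length + input_ids_length  # 1056
min_length = min_new_length + input_ids_length  # 32
```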
@@ -374,11 +374,11 @@ def _prepare_generated_length(
            generation_config.min_length = generation_config.min_new_length + input_ids_length

        return generation_config
-
+
    def _get_cache(
        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
    ) -> Cache:
-
+
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"

        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
@@ -412,7 +412,7 @@ def _get_cache(

        return self.model._cache

-
+
    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
@@ -466,7 +466,7 @@ def _prepare_generation_config(self, generation_config: GenerationConfig, **kwar
        model_kwargs = generation_config.update(**kwargs)

        return generation_config, model_kwargs
-
+
    def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
        """Performs validation related to the resulting generated length"""

@@ -498,7 +498,7 @@ def _validate_generated_length(self, generation_config: GenerationConfig, input_
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
-
+
    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
@@ -526,13 +526,13 @@ def _expand_dict_for_generation(dict_to_expand):
        model_kwargs = _expand_dict_for_generation(model_kwargs)

        return input_ids, model_kwargs
-
+
    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
-
+
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
@@ -564,7 +564,7 @@ def _prepare_attention_mask_for_generation(
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:
-
+
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor

@@ -593,12 +593,12 @@ def _prepare_attention_mask_for_generation(
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask
-
+
def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample=False, seed=None, **kwargs):
    # to work with BaseModel
    if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
        model = patcher.model.diffusion_model
-
+
    if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
        if model.device != patcher.load_device:
            model = model.to(patcher.load_device, dtype=model.dtype)
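The masked blend above is the usual "only infer the attention mask from padding when it is unambiguous" trick; in isolation it behaves like this (toy tensors and a hypothetical pad id, not the PR's helper):

```python
import torch

input_ids = torch.tensor([[7, 7, 5, 9, 2]])
pad_token_id = 7

default_attention_mask = torch.ones_like(input_ids)
attention_mask_from_padding = (input_ids != pad_token_id).long()

# trust the padding-derived mask only when pad cannot be confused with real content;
# here the flag is simply set by hand to show both branches of the blend
can_infer_attention_mask = torch.tensor(True)
attention_mask = (
    attention_mask_from_padding * can_infer_attention_mask
    + default_attention_mask * ~can_infer_attention_mask
)
print(attention_mask)  # tensor([[0, 0, 1, 1, 1]])
```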
@@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0,
        kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
    else:
        main_input_ids = input_ids
-
+
    device = node._cached_autoregressive_sampler.device

    main_input_ids = main_input_ids.to(device)