@@ -25,31 +25,20 @@ def save_config_to_constant_methods(
 ):
     # Initialize metadata with values from model config
     metadata = {
-        "get_dtype": 5 if config.torch_dtype == torch.float16 else 6,
         "get_bos_id": getattr(config, "bos_token_id", None),
         "get_eos_id": getattr(config, "eos_token_id", None),
-        "get_head_dim": getattr(config, "head_dim", None),
-        "get_n_kv_heads": getattr(config, "num_key_value_heads", None),
-        "get_n_layers": getattr(config, "num_hidden_layers", None),
         "get_vocab_size": getattr(config, "vocab_size", None),
-        "get_max_batch_size": 1,
         "get_max_seq_len": getattr(config, "max_position_embeddings", None),
         "use_kv_cache": getattr(generation_config, "use_cache", None),
-        "sliding_window": getattr(config, "sliding_window", None),
-        "decoder_start_token_id": getattr(config, "decoder_start_token_id", None),
-        "use_sdpa_with_kv_cache": "custom_sdpa" in config._attn_implementation,
+        "use_sdpa_with_kv_cache": False,
     }

     # Safely access fields from generation_config if it exists
     if generation_config is not None:
         # Check for cache_config and its attributes
         cache_config = getattr(generation_config, "cache_config", None)
         if cache_config is not None:
-            max_batch_size = getattr(cache_config, "batch_size", None)
             max_seq_len = getattr(cache_config, "max_cache_len", None)
-
-            if max_batch_size is not None:
-                metadata["get_max_batch_size"] = max_batch_size
             if max_seq_len is not None:
                 metadata["get_max_seq_len"] = max_seq_len

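For reference, a minimal, self-contained sketch of what the trimmed-down metadata construction produces after this change. The SimpleNamespace objects below are hypothetical stand-ins for the Hugging Face model config and generation config, not part of this PR:

from types import SimpleNamespace

# Hypothetical stand-ins for a model config and generation config,
# used only to illustrate the metadata produced after this patch.
config = SimpleNamespace(
    bos_token_id=1,
    eos_token_id=2,
    vocab_size=32000,
    max_position_embeddings=4096,
)
generation_config = SimpleNamespace(
    use_cache=True,
    cache_config=SimpleNamespace(max_cache_len=2048),
)

# Mirrors the post-patch body of save_config_to_constant_methods
metadata = {
    "get_bos_id": getattr(config, "bos_token_id", None),
    "get_eos_id": getattr(config, "eos_token_id", None),
    "get_vocab_size": getattr(config, "vocab_size", None),
    "get_max_seq_len": getattr(config, "max_position_embeddings", None),
    "use_kv_cache": getattr(generation_config, "use_cache", None),
    "use_sdpa_with_kv_cache": False,
}
if generation_config is not None:
    cache_config = getattr(generation_config, "cache_config", None)
    if cache_config is not None:
        max_seq_len = getattr(cache_config, "max_cache_len", None)
        if max_seq_len is not None:
            # cache_config's max_cache_len overrides max_position_embeddings
            metadata["get_max_seq_len"] = max_seq_len

print(metadata["get_max_seq_len"])  # 2048, taken from cache_config

Note that with this patch, batch size is no longer recorded in metadata at all, and "get_max_seq_len" is the only entry that cache_config can still override.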