@@ -76,14 +76,24 @@ class AutoModelForCausalLMFactory(ModelFactory):
7676 "max_position_embeddings" : 1024 ,
7777 }
7878
79+ def _get_max_position_embeddings_config (self ) -> Dict [str , Any ]:
80+ """Get the max position embeddings config for the model."""
81+ return {
82+ "max_position_embeddings" : self .max_seq_len ,
83+ }
84+
7985 def __init__ (self , * args , ** kwargs ):
8086 super ().__init__ (* args , ** kwargs )
8187
8288 self ._quant_config : Optional [Dict ] = None
8389
8490 # Ingest defaults for tokenizer and model kwargs
8591 self .tokenizer_kwargs = deep_merge_dicts (self ._tokenizer_defaults , self .tokenizer_kwargs )
86- self .model_kwargs = deep_merge_dicts (self ._model_defaults , self .model_kwargs )
92+ self .model_kwargs = deep_merge_dicts (
93+ self ._model_defaults ,
94+ self .model_kwargs ,
95+ self ._get_max_position_embeddings_config (),
96+ )
8797
8898 # special handling for torch_dtype in model_kwargs since HF does not correctly update
8999 # torch_dtype string to an actual torch.dtype object (only with default)
@@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
         },
     }
 
+    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
+        """Get the max position embeddings config for the model."""
+        return {
+            "max_position_embeddings": self.max_seq_len,
+            "text_config": {
+                "max_position_embeddings": self.max_seq_len,
+            },
+        }
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
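For context on what the new third argument to deep_merge_dicts contributes, here is a minimal sketch, assuming the helper performs a recursive, later-argument-wins merge; the _merge helper and the sample values below are illustrative assumptions, not the repo's actual implementation.

# Hypothetical illustration of the merge order used by the factory:
# later dicts override earlier ones, and nested dicts are merged recursively.
from typing import Any, Dict


def _merge(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(out.get(key), dict):
                out[key] = _merge(out[key], value)  # merge nested configs (e.g. text_config)
            else:
                out[key] = value  # later arguments take precedence
    return out


# Illustrative stand-ins for _model_defaults, user model_kwargs, and the new override:
defaults = {"max_position_embeddings": 1024, "text_config": {"max_position_embeddings": 1024}}
user_kwargs = {"torch_dtype": "bfloat16"}
override = {"max_position_embeddings": 8192, "text_config": {"max_position_embeddings": 8192}}

merged = _merge(defaults, user_kwargs, override)
# merged["max_position_embeddings"] == 8192, the user-supplied torch_dtype is preserved,
# and the nested text_config entry is updated rather than replaced wholesale.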