@@ -426,12 +426,9 @@ def _retry(**override):
426426 _download_model_weights (hf_config , pretrained_model_name_or_path )
427427 logger .info (f"Using custom model implementation for { architectures [0 ]} " )
428428 kwargs .pop ("trust_remote_code" , None )
429+ # TODO: restore weights after initialization.
429430 with local_torch_dtype (torch_dtype , ModelRegistry .model_arch_name_to_cls [architectures [0 ]].__name__ ):
430- return ModelRegistry .model_arch_name_to_cls [architectures [0 ]](
431- hf_config ,
432- * model_args ,
433- ** kwargs ,
434- )
431+ return ModelRegistry .model_arch_name_to_cls [architectures [0 ]](hf_config )
435432
436433 # 3. fallback to parent class
437434 model = None
@@ -564,7 +561,10 @@ def _retry(**override):
564561
565562 # handle model_id passed as config
566563 if isinstance (config , str ):
567- config = AutoConfig .from_pretrained (config , trust_remote_code = kwargs .get ("trust_remote_code" , False ))
564+ config = AutoConfig .from_pretrained (
565+ config , trust_remote_code = kwargs .get ("trust_remote_code" , False ),
566+ attn_implementation = attn_implementation ,
567+ )
568568 # 1. if force_hf is True, we will use the parent class to load and return the model as is
569569 if force_hf :
570570 return cls ._from_config_parent_class (
@@ -578,7 +578,8 @@ def _retry(**override):
578578 # 2. If we have a custom model implementation available, we prioritize that over HF
579579 architectures = get_architectures (config )
580580 if len (architectures ) > 0 and architectures [0 ] in ModelRegistry .model_arch_name_to_cls :
581- return ModelRegistry .model_arch_name_to_cls [architectures [0 ]](config , * model_args , ** kwargs )
581+ with local_torch_dtype (torch_dtype , ModelRegistry .model_arch_name_to_cls [architectures [0 ]].__name__ ):
582+ return ModelRegistry .model_arch_name_to_cls [architectures [0 ]](config )
582583
583584 # 3. fallback to parent class
584585 model = None
0 commit comments