@@ -536,7 +536,6 @@ def apply_model_infrastructure(
536536 autopipeline = None ,
537537 parallelize_fn = None ,
538538 compile_config = None ,
539- model_name_or_path = None ,
540539 load_base_model = False ,
541540 cache_dir = None ,
542541 pretrained_model_name_or_path = "" ,
@@ -570,7 +569,7 @@ def apply_model_infrastructure(
570569 autopipeline: AutoPipeline instance for pipeline parallelism. Default: None
571570 parallelize_fn: Function to apply parallelization (EP + FSDP2). Default: None
572571 compile_config: Compilation configuration. Default: None
573- model_name_or_path : Model name or path for checkpoint loading. Default: None
572+ pretrained_model_name_or_path : Model name or path for checkpoint loading. Default: ""
574573 load_base_model: Whether to load base model weights (True for from_pretrained). Default: False
575574 cache_dir: Cache directory for model weights. Default: None
576575 **_kwargs: Additional keyword arguments (ignored, allows passing extra kwargs)
@@ -659,7 +658,7 @@ def apply_model_infrastructure(
659658 mp ,
660659 device ,
661660 cache_dir ,
662- model_name_or_path ,
661+ pretrained_model_name_or_path ,
663662 lora_a_init ,
664663 load_base_model = load_base_model ,
665664 )
@@ -976,7 +975,6 @@ def _retry(**override):
976975 is_meta_device = is_meta_device ,
977976 device = device ,
978977 compile_config = compile_config ,
979- model_name_or_path = pretrained_model_name_or_path ,
980978 load_base_model = True ,
981979 cache_dir = kwargs .get ("cache_dir" , TRANSFORMERS_CACHE ),
982980 )
@@ -1180,7 +1178,7 @@ def _retry(**override):
11801178 is_meta_device = is_meta_device ,
11811179 device = device ,
11821180 compile_config = compile_config ,
1183- model_name_or_path = getattr (config , "name_or_path" ),
1181+ pretrained_model_name_or_path = getattr (config , "name_or_path" ),
11841182 load_base_model = False ,
11851183 cache_dir = kwargs .get ("cache_dir" , TRANSFORMERS_CACHE ),
11861184 )
0 commit comments