@@ -143,8 +143,9 @@ def get_examples(is_multimodal: bool) -> Dict[str, Any]:
     return data
 
 
-def test_convert_precision(args, hf_model, mg_model, template):
-    torch_dtype = args.test_convert_dtype
+def test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=None):
+    if test_convert_dtype is None:
+        test_convert_dtype = getattr(args, 'test_convert_dtype', torch.float32)
     template.set_mode('train')
     _test_params_sum(mg_model)
 
@@ -166,7 +167,7 @@ def test_convert_precision(args, hf_model, mg_model, template):
     ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else []
     hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules)
     with torch.inference_mode(), _model_cpu_forward_context(
-            hf_modules, torch_dtype, share_embedding=share_embedding):
+            hf_modules, test_convert_dtype, share_embedding=share_embedding):
         hf_inputs.pop('text_position_ids', None)
         hf_logits = hf_model(**hf_inputs).logits
         hf_logits = hf_logits.to('cuda')
@@ -195,8 +196,8 @@ def test_convert_precision(args, hf_model, mg_model, template):
         if n.endswith('router'):
             m.to(mg_dtype)
     with torch.inference_mode(), _model_cpu_forward_context(
-            mg_modules, torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device):
-        mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=torch_dtype)
+            mg_modules, test_convert_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device):
+        mg_logits = forward_step_helper(args, mg_model, mg_inputs, dtype=test_convert_dtype)
     if args.tensor_model_parallel_size > 1 and args.task_type != 'seq_cls':
         from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
         if mg_logits is not None:
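The patch above makes the dtype an explicit, optional parameter of test_convert_precision while keeping args.test_convert_dtype as a backward-compatible fallback, defaulting to torch.float32 when neither is set. A minimal, self-contained sketch of that resolution pattern follows; resolve_test_convert_dtype and the SimpleNamespace stand-ins for args are hypothetical illustrations, not part of the patch.

from types import SimpleNamespace

import torch


def resolve_test_convert_dtype(args, test_convert_dtype=None):
    # Explicit argument wins; otherwise read the attribute off `args`,
    # falling back to float32 when the attribute is absent.
    if test_convert_dtype is None:
        test_convert_dtype = getattr(args, 'test_convert_dtype', torch.float32)
    return test_convert_dtype


# Old-style callers that only configure `args` keep working:
assert resolve_test_convert_dtype(SimpleNamespace(test_convert_dtype=torch.bfloat16)) == torch.bfloat16
# An `args` without the attribute falls back to float32:
assert resolve_test_convert_dtype(SimpleNamespace()) == torch.float32
# New-style callers can override explicitly:
assert resolve_test_convert_dtype(SimpleNamespace(), test_convert_dtype=torch.float16) == torch.float16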