@@ -62,20 +62,20 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         adapter = LlamaStateDictAdapter(config)

         # Load HF model
-        llama_model_hf = (
-            AutoModelForCausalLM.from_pretrained(
-                tiny_llama_checkpoint, attn_implementation="eager", torch_dtype=torch.bfloat16
-            )
-            .to("cuda")
-            .to(torch.bfloat16)  # need to manual cast to bfloat16 since HF initialize weights in float32 dtype
-        )
+        llama_model_hf = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path=tiny_llama_checkpoint,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        llama_model_hf.eval()

         # Build custom model
         llama_model_custom = NeMoAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path=tiny_llama_checkpoint,
             attn_implementation="eager",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
+        llama_model_custom.eval()

         # Verify parameter counts match
         num_params_hf = sum(p.numel() for p in llama_model_hf.parameters())
@@ -89,13 +89,23 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         custom_state_dict_from_hf = adapter.from_hf(hf_state_dict)
         llama_model_custom.load_state_dict(custom_state_dict_from_hf, strict=True)

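+        # Round-trip check: converting the custom state dict back to HF format should reproduce the original HF tensors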
+        s = adapter.to_hf(llama_model_custom.state_dict())
+
+        for n1, p1 in hf_state_dict.items():
+            p2 = s[n1]
+            assert p1.shape == p2.shape, f"Parameter shape mismatch: {p1.shape} != {p2.shape}"
+            assert p1.dtype == p2.dtype, f"Parameter dtype mismatch: {p1.dtype} != {p2.dtype}"
+            assert p1.device == p2.device, f"Parameter device mismatch: {p1.device} != {p2.device}"
+            assert p1.requires_grad == p2.requires_grad, f"Parameter requires_grad mismatch: {p1.requires_grad} != {p2.requires_grad}"
+            assert torch.allclose(p1, p2, atol=1e-5, rtol=1e-5), f"Parameter mismatch: {p1} != {p2}"
+
         # Generate test inputs
         input_ids = torch.randint(0, config.vocab_size, (1, 10)).to("cuda")
         attention_mask = torch.ones((1, 10)).to("cuda")

         # Compare HF → Custom outputs
         with torch.no_grad():
-            output_hf = llama_model_hf(input_ids, attention_mask)
+            output_hf = llama_model_hf(input_ids.clone(), attention_mask.clone())
             output_custom = llama_model_custom(input_ids, attention_mask)

         np.testing.assert_allclose(
@@ -111,13 +121,12 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         hf_state_dict_from_custom = adapter.to_hf(custom_state_dict)

         # Create new HF model and load converted state dict
-        llama_model_hf_converted = (
-            AutoModelForCausalLM.from_pretrained(
-                tiny_llama_checkpoint, attn_implementation="eager", torch_dtype=torch.bfloat16
-            )
-            .to("cuda")
-            .to(torch.bfloat16)
-        )
+        llama_model_hf_converted = AutoModelForCausalLM.from_pretrained(
+            tiny_llama_checkpoint,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16
+        ).to("cuda")
+        llama_model_hf_converted.eval()
         llama_model_hf_converted.load_state_dict(hf_state_dict_from_custom, strict=True)

         # Compare Custom → HF outputs
@@ -191,6 +200,7 @@ def test_export_custom_to_hf_checkpoint(self, tiny_llama_checkpoint):
             attn_implementation="eager",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
+        llama_model_custom.eval()

         # Generate test input
         input_ids = torch.randint(0, config.vocab_size, (1, 10)).to("cuda")
@@ -204,15 +214,12 @@ def test_export_custom_to_hf_checkpoint(self, tiny_llama_checkpoint):
         llama_model_custom.save_pretrained_hf_format(export_path)

         # Load from saved HF checkpoint
-        llama_model_hf_loaded = (
-            AutoModelForCausalLM.from_pretrained(
-                export_path,
-                attn_implementation="eager",
-                torch_dtype=torch.bfloat16,
-            )
-            .to("cuda")
-            .to(torch.bfloat16)
-        )
+        llama_model_hf_loaded = AutoModelForCausalLM.from_pretrained(
+            export_path,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        llama_model_hf_loaded.eval()

         # Compare outputs
         with torch.no_grad():