
Commit 456cfc5

fix
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent d3d0f8f commit 456cfc5

File tree

1 file changed, +31 -24 lines changed


tests/unit_tests/models/llama/test_llama_custom_model.py

Lines changed: 31 additions & 24 deletions
@@ -62,20 +62,20 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         adapter = LlamaStateDictAdapter(config)

         # Load HF model
-        llama_model_hf = (
-            AutoModelForCausalLM.from_pretrained(
-                tiny_llama_checkpoint, attn_implementation="eager", torch_dtype=torch.bfloat16
-            )
-            .to("cuda")
-            .to(torch.bfloat16)  # need to manual cast to bfloat16 since HF initialize weights in float32 dtype
-        )
+        llama_model_hf = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path=tiny_llama_checkpoint,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        llama_model_hf.eval()

         # Build custom model
         llama_model_custom = NeMoAutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path=tiny_llama_checkpoint,
             attn_implementation="eager",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
+        llama_model_custom.eval()

         # Verify parameter counts match
         num_params_hf = sum(p.numel() for p in llama_model_hf.parameters())
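
The reworked loading path above can be read in isolation as the following pattern (a minimal sketch; the checkpoint path is a placeholder standing in for the tiny_llama_checkpoint fixture). Requesting torch_dtype at load time avoids the float32-then-cast detour the deleted comment described, and eval() removes dropout and other train-time nondeterminism before the two implementations are compared.

import torch
from transformers import AutoModelForCausalLM

# Placeholder path; in the test this is the tiny_llama_checkpoint fixture.
checkpoint = "path/to/tiny-llama-checkpoint"

# Passing torch_dtype to from_pretrained materializes the weights in bfloat16
# directly, replacing the old trailing .to(torch.bfloat16) cast.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
).to("cuda")

# eval() disables dropout and other training-only behavior, keeping forward
# passes deterministic when two implementations are compared.
model.eval()
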
@@ -89,13 +89,23 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         custom_state_dict_from_hf = adapter.from_hf(hf_state_dict)
         llama_model_custom.load_state_dict(custom_state_dict_from_hf, strict=True)

+        s = adapter.to_hf(llama_model_custom.state_dict())
+
+        for n1, p1 in hf_state_dict.items():
+            p2 = s[n1]
+            assert p1.shape == p2.shape, f"Parameter shape mismatch: {p1.shape} != {p2.shape}"
+            assert p1.dtype == p2.dtype, f"Parameter dtype mismatch: {p1.dtype} != {p2.dtype}"
+            assert p1.device == p2.device, f"Parameter device mismatch: {p1.device} != {p2.device}"
+            assert p1.requires_grad == p2.requires_grad, f"Parameter requires_grad mismatch: {p1.requires_grad} != {p2.requires_grad}"
+            assert torch.allclose(p1, p2, atol=1e-5, rtol=1e-5), f"Parameter mismatch: {p1} != {p2}"
+
         # Generate test inputs
         input_ids = torch.randint(0, config.vocab_size, (1, 10)).to("cuda")
         attention_mask = torch.ones((1, 10)).to("cuda")

         # Compare HF → Custom outputs
         with torch.no_grad():
-            output_hf = llama_model_hf(input_ids, attention_mask)
+            output_hf = llama_model_hf(input_ids.clone(), attention_mask.clone())
             output_custom = llama_model_custom(input_ids, attention_mask)

         np.testing.assert_allclose(
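
The added block is effectively a state-dict round-trip check: the custom model's weights are mapped back to HF naming via adapter.to_hf and compared tensor-by-tensor against the original hf_state_dict. A minimal sketch of that comparison as a standalone helper, using the same tolerances as the diff (the function name and signature are illustrative, not part of the test suite):

import torch

def assert_state_dicts_match(reference: dict, candidate: dict, atol: float = 1e-5, rtol: float = 1e-5):
    """Require every tensor in `reference` to exist in `candidate` with identical
    shape, dtype, device, and requires_grad flag, and numerically close values."""
    for name, ref in reference.items():
        assert name in candidate, f"Missing parameter: {name}"
        cand = candidate[name]
        assert ref.shape == cand.shape, f"{name}: shape {ref.shape} != {cand.shape}"
        assert ref.dtype == cand.dtype, f"{name}: dtype {ref.dtype} != {cand.dtype}"
        assert ref.device == cand.device, f"{name}: device {ref.device} != {cand.device}"
        assert ref.requires_grad == cand.requires_grad, f"{name}: requires_grad differs"
        assert torch.allclose(ref, cand, atol=atol, rtol=rtol), f"{name}: values differ"

The clone() calls on the shared inputs in the same hunk appear to guard against either forward pass mutating them in place before the second model sees them.
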
@@ -111,13 +121,12 @@ def test_model_matches_hf_with_adapter_bidirectional(self, tiny_llama_checkpoint
         hf_state_dict_from_custom = adapter.to_hf(custom_state_dict)

         # Create new HF model and load converted state dict
-        llama_model_hf_converted = (
-            AutoModelForCausalLM.from_pretrained(
-                tiny_llama_checkpoint, attn_implementation="eager", torch_dtype=torch.bfloat16
-            )
-            .to("cuda")
-            .to(torch.bfloat16)
-        )
+        llama_model_hf_converted = AutoModelForCausalLM.from_pretrained(
+            tiny_llama_checkpoint,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16
+        ).to("cuda")
+        llama_model_hf_converted.eval()
         llama_model_hf_converted.load_state_dict(hf_state_dict_from_custom, strict=True)

         # Compare Custom → HF outputs
@@ -191,6 +200,7 @@ def test_export_custom_to_hf_checkpoint(self, tiny_llama_checkpoint):
             attn_implementation="eager",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
+        llama_model_custom.eval()

         # Generate test input
         input_ids = torch.randint(0, config.vocab_size, (1, 10)).to("cuda")
@@ -204,15 +214,12 @@ def test_export_custom_to_hf_checkpoint(self, tiny_llama_checkpoint):
         llama_model_custom.save_pretrained_hf_format(export_path)

         # Load from saved HF checkpoint
-        llama_model_hf_loaded = (
-            AutoModelForCausalLM.from_pretrained(
-                export_path,
-                attn_implementation="eager",
-                torch_dtype=torch.bfloat16,
-            )
-            .to("cuda")
-            .to(torch.bfloat16)
-        )
+        llama_model_hf_loaded = AutoModelForCausalLM.from_pretrained(
+            export_path,
+            attn_implementation="eager",
+            torch_dtype=torch.bfloat16,
+        ).to("cuda")
+        llama_model_hf_loaded.eval()

         # Compare outputs
         with torch.no_grad():
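
The final hunk applies the same loading cleanup to the export test: save with save_pretrained_hf_format, reload through AutoModelForCausalLM.from_pretrained, switch to eval mode, then compare outputs. A minimal sketch of that comparison step, assuming both models return HF-style outputs with a .logits field; the helper name and tolerances are placeholders, not values from the test:

import numpy as np
import torch

def compare_logits(model_a, model_b, input_ids, attention_mask, atol=1e-2, rtol=1e-2):
    """Run both models on identical inputs and require their logits to agree."""
    with torch.no_grad():
        out_a = model_a(input_ids, attention_mask)
        out_b = model_b(input_ids, attention_mask)
    np.testing.assert_allclose(
        out_a.logits.float().cpu().numpy(),
        out_b.logits.float().cpu().numpy(),
        atol=atol,
        rtol=rtol,
    )
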
