Commit 1f026ad

Commit message: update
1 parent 5a47442 commit 1f026ad

12 files changed: +1000 -12 lines changed

tests/models/test_modeling_common.py

Lines changed: 12 additions & 12 deletions
@@ -317,9 +317,9 @@ def test_local_files_only_with_sharded_checkpoint(self):
                 repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
             )

-            assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
-                "Model parameters don't match!"
-            )
+            assert all(
+                torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
+            ), "Model parameters don't match!"

             # Remove a shard file
             cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ def test_local_files_only_with_sharded_checkpoint(self):

             # Verify error mentions the missing shard
             error_msg = str(context.exception)
-            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
-                f"Expected error about missing shard, got: {error_msg}"
-            )
+            assert (
+                cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
+            ), f"Expected error about missing shard, got: {error_msg}"

     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ def test_one_request_upon_cached(self):
                 )

             download_requests = [r.method for r in m.request_history]
-            assert download_requests.count("HEAD") == 3, (
-                "3 HEAD requests one for config, one for model, and one for shard index file."
-            )
+            assert (
+                download_requests.count("HEAD") == 3
+            ), "3 HEAD requests one for config, one for model, and one for shard index file."
             assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

             with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ def test_one_request_upon_cached(self):
                 )

             cache_requests = [r.method for r in m.request_history]
-            assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
-                "We should call only `model_info` to check for commit hash and knowing if shard index is present."
-            )
+            assert (
+                "HEAD" == cache_requests[0] and len(cache_requests) == 2
+            ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."

     def test_weight_overwrite(self):
         with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .common import ModelTesterMixin
from .single_file import SingleFileTesterMixin

tests/models/testing_utils/attention.py

Whitespace-only changes.
Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from typing import Dict, List, Tuple

import pytest
import torch

from ...testing_utils import torch_device


class ModelTesterMixin:
    """
    Base mixin class for model testing with common test methods.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
    - base_precision: Default tolerance for floating point comparisons (default: 1e-3)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    model_class = None
    base_precision = 1e-3

    def get_init_dict(self):
        raise NotImplementedError("get_init_dict must be implemented by subclasses.")

    def get_dummy_inputs(self):
        raise NotImplementedError(
            "get_dummy_inputs must be implemented by subclasses. It should return inputs_dict."
        )

    def check_device_map_is_respected(self, model, device_map):
        """Helper method to check if device map is correctly applied to model parameters."""
        for param_name, param in model.named_parameters():
            # Find device in device_map
            while len(param_name) > 0 and param_name not in device_map:
                param_name = ".".join(param_name.split(".")[:-1])
            if param_name not in device_map:
                raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")

            param_device = device_map[param_name]
            if param_device in ["cpu", "disk"]:
                assert param.device == torch.device(
                    "meta"
                ), f"Expected device 'meta' for {param_name}, got {param.device}"
            else:
                assert param.device == torch.device(
                    param_device
                ), f"Expected device {param_device} for {param_name}, got {param.device}"

    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        """Test that model can be saved and loaded with save_pretrained/from_pretrained."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs())

            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**self.get_dummy_inputs())

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
        """Test save_pretrained/from_pretrained with variant parameter."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16")
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")

            # non-variant cannot be loaded
            with pytest.raises(OSError) as exc_info:
                self.model_class.from_pretrained(tmpdirname)

            # make sure that error message states what keys are missing
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)

            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs())
            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**self.get_dummy_inputs())

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        max_diff = (image - new_image).abs().max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_from_save_pretrained_dtype(self):
        """Test save_pretrained/from_pretrained preserves dtype correctly."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                if (
                    hasattr(self.model_class, "_keep_in_fp32_modules")
                    and self.model_class._keep_in_fp32_modules is None
                ):
                    new_model = self.model_class.from_pretrained(
                        tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
                    )
                    assert new_model.dtype == dtype

    def test_determinism(self, expected_max_diff=1e-5):
        """Test that model outputs are deterministic across multiple forward passes."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = model(**self.get_dummy_inputs())
            if isinstance(first, dict):
                first = first.to_tuple()[0]

            second = model(**self.get_dummy_inputs())
            if isinstance(second, dict):
                second = second.to_tuple()[0]

        # Remove NaN values and compute max difference
        first_flat = first.flatten()
        second_flat = second.flatten()

        # Filter out NaN values
        mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
        first_filtered = first_flat[mask]
        second_filtered = second_flat[mask]

        max_diff = torch.abs(first_filtered - second_filtered).max().item()
        assert (
            max_diff <= expected_max_diff
        ), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"

    def test_output(self, expected_output_shape=None):
        """Test that model produces output with expected shape."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        inputs_dict = self.get_dummy_inputs()
        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        assert output is not None, "Model output is None"
        assert (
            output.shape == expected_output_shape
        ), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"

    def test_model_from_pretrained(self):
        """Test that model loaded from pretrained matches original model."""
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, safe_serialization=False)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameters shape are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            assert (
                param_1.shape == param_2.shape
            ), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"

        with torch.no_grad():
            output_1 = model(**self.get_dummy_inputs())

            if isinstance(output_1, dict):
                output_1 = output_1.to_tuple()[0]

            output_2 = new_model(**self.get_dummy_inputs())

            if isinstance(output_2, dict):
                output_2 = output_2.to_tuple()[0]

        assert (
            output_1.shape == output_2.shape
        ), f"Output shape mismatch. Original: {output_1.shape}, loaded: {output_2.shape}"

    def test_outputs_equivalence(self):
        """Test that dict and tuple outputs are equivalent."""

        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                assert torch.allclose(
                    set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                ), (
                    "Tuple and dict output are not equal. Difference:"
                    f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                    f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                    f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                )

        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**self.get_dummy_inputs())
            outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

    def test_model_config_to_json_string(self):
        """Test model config can be serialized to JSON string."""
        model = self.model_class(**self.get_init_dict())

        json_string = model.config.to_json_string()
        assert isinstance(json_string, str), "Config to_json_string should return a string"
        assert len(json_string) > 0, "JSON string should not be empty"

    def test_keep_in_fp32_modules(self):
        r"""
        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32 when the model
        is loaded in fp16/bf16. Also ensures that inference works.
        """
        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
            pytest.skip("Model does not have _keep_in_fp32_modules")

        fp32_modules = self.model_class._keep_in_fp32_modules

        for torch_dtype in [torch.bfloat16, torch.float16]:
            model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, torch_dtype=torch_dtype).to(
                torch_device
            )
            for name, param in model.named_parameters():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
                    # Parameters under `_keep_in_fp32_modules` must stay in fp32 regardless of the load dtype.
                    assert param.dtype == torch.float32
                else:
                    assert param.dtype == torch_dtype
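
For orientation, here is a minimal sketch of how a model test class might plug into the new ModelTesterMixin contract. It is not part of this commit: the UNet2DModelTests name, the tiny UNet2DModel config, and the import path for the mixin are illustrative assumptions.

# Illustrative only, not part of the diff above.
import torch
from diffusers import UNet2DModel
from diffusers.utils.testing_utils import torch_device

from ..testing_utils import ModelTesterMixin  # hypothetical import path for the new mixin


class UNet2DModelTests(ModelTesterMixin):
    model_class = UNet2DModel
    main_input_name = "sample"

    def get_init_dict(self):
        # Tiny config so the shared tests run quickly on CPU; values are illustrative.
        return {
            "sample_size": 32,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 2,
            "block_out_channels": (4, 8),
            "norm_num_groups": 2,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        # Inputs live on the same device the mixin moves the model to.
        return {
            "sample": torch.randn(4, 3, 32, 32).to(torch_device),
            "timestep": torch.tensor([10]).to(torch_device),
        }

With a subclass like this, pytest collects the shared test_* methods (save/load round-trips, determinism, dtype handling, output shapes) from the mixin. Note that test_keep_in_fp32_modules additionally reads self.pretrained_model_name_or_path, so a Hub checkpoint id would need to be set for that test to run.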
