Skip to content

Commit a686c65

Browse files
committed
add test
1 parent 4b56b87 commit a686c65

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
AttnProcessorNPU,
4545
XFormersAttnProcessor,
4646
)
47+
from diffusers.models.auto_model import AUTO_MODEL_MAPPING, AutoModel
4748
from diffusers.training_utils import EMAModel
4849
from diffusers.utils import (
4950
SAFE_WEIGHTS_INDEX_NAME,
@@ -1568,6 +1569,52 @@ def run_forward(model):
15681569
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
15691570
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
15701571

1572+
def test_auto_model(self, expected_max_diff=5e-5):
1573+
if self.model_class not in list(AUTO_MODEL_MAPPING.values()):
1574+
self.skipTest(f"Skipping auto-model test: {self.model_class.__name__} is not in AUTO_MODEL_MAPPING")
1575+
1576+
if self.forward_requires_fresh_args:
1577+
model = self.model_class(**self.init_dict)
1578+
else:
1579+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1580+
model = self.model_class(**init_dict)
1581+
1582+
model = model.eval()
1583+
model = model.to(torch_device)
1584+
1585+
if hasattr(model, "set_default_attn_processor"):
1586+
model.set_default_attn_processor()
1587+
1588+
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
1589+
model.save_pretrained(tmpdirname, safe_serialization=False)
1590+
1591+
auto_model = AutoModel.from_pretrained(tmpdirname)
1592+
if hasattr(auto_model, "set_default_attn_processor"):
1593+
auto_model.set_default_attn_processor()
1594+
1595+
auto_model = auto_model.eval()
1596+
auto_model = auto_model.to(torch_device)
1597+
1598+
with torch.no_grad():
1599+
if self.forward_requires_fresh_args:
1600+
output_original = model(**self.inputs_dict(0))
1601+
output_auto = auto_model(**self.inputs_dict(0))
1602+
else:
1603+
output_original = model(**inputs_dict)
1604+
output_auto = auto_model(**inputs_dict)
1605+
1606+
if isinstance(output_original, dict):
1607+
output_original = output_original.to_tuple()[0]
1608+
if isinstance(output_auto, dict):
1609+
output_auto = output_auto.to_tuple()[0]
1610+
1611+
max_diff = (output_original - output_auto).abs().max().item()
1612+
self.assertLessEqual(
1613+
max_diff,
1614+
expected_max_diff,
1615+
f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
1616+
)
1617+
15711618

15721619
@is_staging_test
15731620
class ModelPushToHubTester(unittest.TestCase):

0 commit comments

Comments
 (0)