|
44 | 44 | AttnProcessorNPU, |
45 | 45 | XFormersAttnProcessor, |
46 | 46 | ) |
| 47 | +from diffusers.models.auto_model import AUTO_MODEL_MAPPING, AutoModel |
47 | 48 | from diffusers.training_utils import EMAModel |
48 | 49 | from diffusers.utils import ( |
49 | 50 | SAFE_WEIGHTS_INDEX_NAME, |
@@ -1568,6 +1569,52 @@ def run_forward(model): |
1568 | 1569 | self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) |
1569 | 1570 | self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) |
1570 | 1571 |
|
| 1572 | + def test_auto_model(self, expected_max_diff=5e-5): |
| 1573 | + if self.model_class not in list(AUTO_MODEL_MAPPING.values()): |
| 1574 | + self.skipTest(f"Skipping auto-model test: {self.model_class.__name__} is not in AUTO_MODEL_MAPPING") |
| 1575 | + |
| 1576 | + if self.forward_requires_fresh_args: |
| 1577 | + model = self.model_class(**self.init_dict) |
| 1578 | + else: |
| 1579 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 1580 | + model = self.model_class(**init_dict) |
| 1581 | + |
| 1582 | + model = model.eval() |
| 1583 | + model = model.to(torch_device) |
| 1584 | + |
| 1585 | + if hasattr(model, "set_default_attn_processor"): |
| 1586 | + model.set_default_attn_processor() |
| 1587 | + |
| 1588 | + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: |
| 1589 | + model.save_pretrained(tmpdirname, safe_serialization=False) |
| 1590 | + |
| 1591 | + auto_model = AutoModel.from_pretrained(tmpdirname) |
| 1592 | + if hasattr(auto_model, "set_default_attn_processor"): |
| 1593 | + auto_model.set_default_attn_processor() |
| 1594 | + |
| 1595 | + auto_model = auto_model.eval() |
| 1596 | + auto_model = auto_model.to(torch_device) |
| 1597 | + |
| 1598 | + with torch.no_grad(): |
| 1599 | + if self.forward_requires_fresh_args: |
| 1600 | + output_original = model(**self.inputs_dict(0)) |
| 1601 | + output_auto = auto_model(**self.inputs_dict(0)) |
| 1602 | + else: |
| 1603 | + output_original = model(**inputs_dict) |
| 1604 | + output_auto = auto_model(**inputs_dict) |
| 1605 | + |
| 1606 | + if isinstance(output_original, dict): |
| 1607 | + output_original = output_original.to_tuple()[0] |
| 1608 | + if isinstance(output_auto, dict): |
| 1609 | + output_auto = output_auto.to_tuple()[0] |
| 1610 | + |
| 1611 | + max_diff = (output_original - output_auto).abs().max().item() |
| 1612 | + self.assertLessEqual( |
| 1613 | + max_diff, |
| 1614 | + expected_max_diff, |
| 1615 | + f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}", |
| 1616 | + ) |
| 1617 | + |
1571 | 1618 |
|
1572 | 1619 | @is_staging_test |
1573 | 1620 | class ModelPushToHubTester(unittest.TestCase): |
|
0 commit comments