Skip to content

Commit 2b72ff3

Browse files
committed
updates
1 parent b982eca commit 2b72ff3

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
"T2IAdapter",
137137
"T5FilmDecoder",
138138
"Transformer2DModel",
139+
"TransformerTemporalModel",
139140
"UNet1DModel",
140141
"UNet2DConditionModel",
141142
"UNet2DModel",
@@ -649,6 +650,7 @@
649650
T2IAdapter,
650651
T5FilmDecoder,
651652
Transformer2DModel,
653+
TransformerTemporalModel,
652654
UNet1DModel,
653655
UNet2DConditionModel,
654656
UNet2DModel,

tests/models/test_modeling_common.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,11 @@ def get_memory_usage(storage_dtype, compute_dtype):
14281428
@parameterized.expand([None, "foo"])
14291429
def test_works_with_automodel(self, subfolder):
14301430
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1431+
has_generator_in_inputs = False
1432+
if "generator" in inputs_dict:
1433+
has_generator_in_inputs = True
1434+
inputs_dict["generator"] = torch.manual_seed(0)
1435+
14311436
model = self.model_class(**config).eval()
14321437
model_cls_name = model.__class__.__name__
14331438
model.to(torch_device)
@@ -1438,14 +1443,18 @@ def test_works_with_automodel(self, subfolder):
14381443
with tempfile.TemporaryDirectory() as tmpdir:
14391444
path = os.path.join(tmpdir, subfolder) if subfolder else tmpdir
14401445
model.save_pretrained(path)
1441-
automodel = AutoModel.from_pretrained(tmpdir, subfolder=subfolder).to(torch_device)
1446+
automodel = AutoModel.from_pretrained(tmpdir, subfolder=subfolder).eval()
1447+
automodel.to(torch_device)
14421448

14431449
automodel_cls_name = automodel.__class__.__name__
14441450
self.assertTrue(model_cls_name == automodel_cls_name)
14451451
for p1, p2 in zip(model.parameters(), automodel.parameters()):
1446-
self.assertTrue(torch.equal(p1, p2))
1452+
if not (torch.isnan(p1).any() and torch.isnan(p2).any()):
1453+
self.assertTrue(torch.equal(p1, p2))
14471454

14481455
torch.manual_seed(0)
1456+
if has_generator_in_inputs:
1457+
inputs_dict["generator"] = torch.manual_seed(0)
14491458
output_automodel = model(**inputs_dict, return_dict=False)[0]
14501459

14511460
self.assertTrue(torch.allclose(output[0], output_automodel[0], atol=1e-5))

0 commit comments

Comments
 (0)