Skip to content

Commit 31adeb4

Browse files
authored
[Tests] fix sharding tests (#8764)
fix sharding tests
1 parent a7b9634 commit 31adeb4

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,10 @@ def __init__(
415415

416416
if set_W_to_weight:
417417
# to delete later
418+
del self.weight
418419
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
419-
420420
self.weight = self.W
421+
del self.W
421422

422423
def forward(self, x):
423424
if self.log:

tests/models/autoencoders/test_models_vae.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,10 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
361361
forward_requires_fresh_args = True
362362

363363
def inputs_dict(self, seed=None):
364-
generator = torch.Generator("cpu")
365-
if seed is not None:
366-
generator.manual_seed(0)
364+
if seed is None:
365+
generator = torch.Generator("cpu").manual_seed(0)
366+
else:
367+
generator = torch.Generator("cpu").manual_seed(seed)
367368
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
368369

369370
return {"sample": image, "generator": generator}

tests/models/test_modeling_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,11 +905,13 @@ def test_sharded_checkpoints(self):
905905
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
906906
self.assertTrue(actual_num_shards == expected_num_shards)
907907

908-
new_model = self.model_class.from_pretrained(tmp_dir)
908+
new_model = self.model_class.from_pretrained(tmp_dir).eval()
909909
new_model = new_model.to(torch_device)
910910

911911
torch.manual_seed(0)
912+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
912913
new_output = new_model(**inputs_dict)
914+
913915
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
914916

915917
@require_torch_gpu
@@ -940,6 +942,7 @@ def test_sharded_checkpoints_device_map(self):
940942
new_model = new_model.to(torch_device)
941943

942944
torch.manual_seed(0)
945+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
943946
new_output = new_model(**inputs_dict)
944947
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
945948

0 commit comments

Comments
 (0)