Skip to content

Commit 44f64d2

Browse files
committed
update
1 parent 474a248 commit 44f64d2

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

tests/models/test_modeling_common.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
15281528
test_fn(torch.float8_e5m2, torch.float32)
15291529
test_fn(torch.float8_e4m3fn, torch.bfloat16)
15301530

1531+
@torch.no_grad()
15311532
def test_layerwise_casting_inference(self):
15321533
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
15331534

15341535
torch.manual_seed(0)
15351536
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1536-
model = self.model_class(**config).eval()
1537-
model = model.to(torch_device)
1538-
base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
1537+
model = self.model_class(**config)
1538+
model.eval()
1539+
model.to(torch_device)
1540+
base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
15391541

15401542
def check_linear_dtype(module, storage_dtype, compute_dtype):
15411543
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1706,10 +1708,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17061708
if not self.model_class._supports_group_offloading:
17071709
pytest.skip("Model does not support group offloading.")
17081710

1709-
torch.manual_seed(0)
1710-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1711-
model = self.model_class(**init_dict)
1712-
17131711
torch.manual_seed(0)
17141712
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17151713
model = self.model_class(**init_dict)
@@ -1725,7 +1723,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17251723
**additional_kwargs,
17261724
)
17271725
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
1728-
assert has_safetensors, "No safetensors found in the directory."
1726+
self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
17291727
_ = model(**inputs_dict)[0]
17301728

17311729
def test_auto_model(self, expected_max_diff=5e-5):

0 commit comments

Comments
 (0)