diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py index 15fe8e02e00d..7ab79a0bb857 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py @@ -110,8 +110,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None: self.patch_size = patch_size self.patch_method = patch_method - self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False) - self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False) + wavelets = _WAVELETS.get(patch_method).clone() + arange = torch.arange(wavelets.shape[0]) + + self.register_buffer("wavelets", wavelets, persistent=False) + self.register_buffer("_arange", arange, persistent=False) def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor: dtype = hidden_states.dtype @@ -185,12 +188,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"): self.patch_size = patch_size self.patch_method = patch_method - self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False) - self.register_buffer( - "_arange", - torch.arange(_WAVELETS[patch_method].shape[0]), - persistent=False, - ) + wavelets = _WAVELETS.get(patch_method).clone() + arange = torch.arange(wavelets.shape[0]) + + self.register_buffer("wavelets", wavelets, persistent=False) + self.register_buffer("_arange", arange, persistent=False) def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor: device = hidden_states.device diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index e8b41ddbfd87..eba8cc23b7e1 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype): test_fn(torch.float8_e5m2, torch.float32) test_fn(torch.float8_e4m3fn, torch.bfloat16) + @torch.no_grad() def test_layerwise_casting_inference(self): from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**config).eval() - model = model.to(torch_device) - base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + model = self.model_class(**config) + model.eval() + model.to(torch_device) + base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy() def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN @@ -1573,6 +1575,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype): test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16) @require_torch_accelerator + @torch.no_grad() def test_layerwise_casting_memory(self): MB_TOLERANCE = 0.2 LEAST_COMPUTE_CAPABILITY = 8.0 @@ -1706,10 +1709,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): if not self.model_class._supports_group_offloading: pytest.skip("Model does not support group offloading.") - torch.manual_seed(0) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - torch.manual_seed(0) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -1725,7 +1724,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type): **additional_kwargs, ) has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") - assert has_safetensors, "No safetensors found in the directory." + self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.") _ = model(**inputs_dict)[0] def test_auto_model(self, expected_max_diff=5e-5):