diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 757cbd65c6bf..c78d72dc4a2c 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -202,8 +202,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        self.freqs = self.freqs.to(hidden_states.device)
-        freqs = self.freqs.split_with_sizes(
+        freqs = self.freqs.to(hidden_states.device)
+        freqs = freqs.split_with_sizes(
             [
                 self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                 self.attention_head_dim // 6,
diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py
index 3ac64c628988..8270c2ee21b0 100644
--- a/tests/models/transformers/test_models_transformer_wan.py
+++ b/tests/models/transformers/test_models_transformer_wan.py
@@ -17,7 +17,14 @@
 import torch
 
 from diffusers import WanTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    is_torch_compile,
+    require_torch_2,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 from ..test_modeling_common import ModelTesterMixin
 
@@ -79,3 +86,18 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"WanTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @require_torch_gpu
+    @require_torch_2
+    @is_torch_compile
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        torch._dynamo.reset()
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
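
For context on the one-line change in `transformer_wan.py`: reassigning `self.freqs` inside `forward` mutates module state on every call, so once the tensor has been moved to the accelerator the attribute no longer matches the guards torch._dynamo recorded at compile time, which can force a recompile on the next call (exactly what the new test guards against with `fullgraph=True` and `error_on_recompile=True`). Reading the attribute into a local variable leaves the module untouched. Below is a minimal, self-contained sketch of that pattern, not part of the PR; the module and tensor names are made up for illustration.

```python
# Minimal sketch (not from the PR) of the pattern the patch adopts. Only the
# "local variable instead of attribute reassignment" idea is the point here.
import torch


class RotaryFreqs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute precomputed on CPU, standing in for the rotary frequencies.
        self.freqs = torch.randn(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Problematic pattern (what the patch removes): reassigning the attribute mutates
        # module state, so after a first call that moves self.freqs to another device the
        # recorded guards no longer hold and Dynamo may recompile:
        #   self.freqs = self.freqs.to(x.device)
        # Fixed pattern: keep the device-moved tensor in a local variable.
        freqs = self.freqs.to(x.device)
        return x + freqs.sum()


model = torch.compile(RotaryFreqs(), fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    # With the fixed pattern both calls reuse the same compiled graph; with the
    # commented-out reassignment, running on an accelerator would change self.freqs'
    # device after the first call and the second call could trigger a recompile.
    _ = model(torch.randn(4))
    _ = model(torch.randn(4))
```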