From 29419af8791cf7193930338ee77156a2035c52bc Mon Sep 17 00:00:00 2001 From: tongyu0924 Date: Tue, 29 Apr 2025 22:05:33 +0800 Subject: [PATCH 1/3] [tests] Add torch.compile() test for WanTransformer3DModel --- .../transformers/test_models_transformer_wan.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 3ac64c628988..b17483f2e3e1 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -20,6 +20,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin +from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow enable_full_determinism() @@ -79,3 +80,18 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"WanTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @require_torch_gpu + @require_torch_2 + @is_torch_compile + @slow + def test_torch_compile_recompilation_and_graph_break(self): + torch._dynamo.reset() + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict).to(torch_device) + model = torch.compile(model, fullgraph=True) + + with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) From e325cb15411562ce56023e3903aeb05f528090e7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 1 May 2025 10:53:57 +0530 Subject: [PATCH 2/3] fix wan recompilation issues. --- src/diffusers/models/transformers/transformer_wan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 757cbd65c6bf..c78d72dc4a2c 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -202,8 +202,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - self.freqs = self.freqs.to(hidden_states.device) - freqs = self.freqs.split_with_sizes( + freqs = self.freqs.to(hidden_states.device) + freqs = freqs.split_with_sizes( [ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), self.attention_head_dim // 6, From 1ff3dc1c9fc6de138eef0f9ad602f8342f68ae14 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 1 May 2025 10:59:18 +0530 Subject: [PATCH 3/3] style --- .../transformers/test_models_transformer_wan.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index b17483f2e3e1..8270c2ee21b0 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -17,10 +17,16 @@ import torch from diffusers import WanTransformer3DModel -from diffusers.utils.testing_utils import enable_full_determinism, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + is_torch_compile, + require_torch_2, + require_torch_gpu, + slow, + torch_device, +) from ..test_modeling_common import ModelTesterMixin -from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow enable_full_determinism() @@ -80,7 +86,7 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"WanTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - + @require_torch_gpu @require_torch_2 @is_torch_compile