Skip to content

Commit 1b4ba9a

Browse files
committed
update
1 parent 35ee819 commit 1b4ba9a

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@
1717
import torch
1818

1919
from diffusers import HunyuanVideoTransformer3DModel
20-
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
21-
from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow
20+
from diffusers.utils.testing_utils import (
21+
enable_full_determinism,
22+
is_torch_compile,
23+
require_torch_2,
24+
require_torch_gpu,
25+
slow,
26+
torch_device,
27+
)
2228

2329
from ..test_modeling_common import ModelTesterMixin
2430

@@ -86,18 +92,14 @@ def prepare_init_args_and_inputs_for_common(self):
8692
inputs_dict = self.dummy_input
8793
return init_dict, inputs_dict
8894

89-
def test_gradient_checkpointing_is_applied(self):
90-
expected_set = {"HunyuanVideoTransformer3DModel"}
91-
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
92-
9395
def test_gradient_checkpointing_is_applied(self):
9496
expected_set = {"HunyuanVideoTransformer3DModel"}
9597
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
9698

9799
@require_torch_gpu
98100
@require_torch_2
99101
@is_torch_compile
100-
@slow
102+
@slow
101103
def test_torch_compile_recompilation_and_graph_break(self):
102104
torch._dynamo.reset()
103105
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -191,7 +193,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
191193
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
192194
_ = model(**inputs_dict)
193195
_ = model(**inputs_dict)
194-
196+
195197

196198
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
197199
model_class = HunyuanVideoTransformer3DModel
@@ -257,7 +259,7 @@ def test_output(self):
257259
def test_gradient_checkpointing_is_applied(self):
258260
expected_set = {"HunyuanVideoTransformer3DModel"}
259261
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
260-
262+
261263
@require_torch_gpu
262264
@require_torch_2
263265
@is_torch_compile
@@ -336,7 +338,7 @@ def prepare_init_args_and_inputs_for_common(self):
336338

337339
def test_output(self):
338340
super().test_output(expected_output_shape=(1, *self.output_shape))
339-
341+
340342
def test_gradient_checkpointing_is_applied(self):
341343
expected_set = {"HunyuanVideoTransformer3DModel"}
342344
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)