Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, require_torch_2, is_torch_compile, slow

from ..test_modeling_common import ModelTesterMixin

Expand Down Expand Up @@ -88,6 +89,25 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done in a separate PR.


@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)


class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
Expand Down Expand Up @@ -157,6 +177,21 @@ def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)


class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
Expand Down Expand Up @@ -222,6 +257,21 @@ def test_output(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)


class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
Expand Down Expand Up @@ -286,7 +336,22 @@ def prepare_init_args_and_inputs_for_common(self):

def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))

def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)