Commit e0566e6

change to modeling level test.
1 parent 1a934b2 commit e0566e6

3 files changed (+31 lines, -19 lines)

tests/models/test_modeling_common.py  (29 additions, 0 deletions)

@@ -1714,6 +1714,35 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)


+class TorchCompileTesterMixin:
+    def setUp(self):
+        # clean up the VRAM before each test
+        super().setUp()
+        torch._dynamo.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        # clean up the VRAM after each test in case of CUDA runtime errors
+        super().tearDown()
+        torch._dynamo.reset()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    @require_torch_gpu
+    @require_torch_2
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        torch._dynamo.reset()
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        model = torch.compile(model, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+            _ = model(**inputs_dict)
+
+
 @slow
 @require_torch_2
 @require_torch_accelerator
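
For context, the model-level check works because torch.compile(..., fullgraph=True) turns any graph break into a hard error, and torch._dynamo.config.patch(error_on_recompile=True) turns any recompilation during the guarded forward pass into an error. Below is a minimal standalone sketch of that mechanism, not part of this commit; the TinyModel module and tensor shapes are illustrative placeholders.

# Standalone sketch (illustrative, not from this diff): fullgraph=True makes
# graph breaks raise, and error_on_recompile=True makes a second forward pass
# raise if it triggers recompilation.
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


model = torch.compile(TinyModel(), fullgraph=True)  # graph breaks become errors

with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    x = torch.randn(2, 16)
    _ = model(x)  # first call compiles the graph
    _ = model(x)  # same shape/dtype: must reuse the compiled graph, or this raises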

tests/models/transformers/test_models_transformer_flux.py  (2 additions, 2 deletions)

@@ -22,7 +22,7 @@
 from diffusers.models.embeddings import ImageProjection
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device

-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


 enable_full_determinism()
@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
     return ip_state_dict


-class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.

tests/pipelines/test_pipelines_common.py  (0 additions, 17 deletions)

@@ -56,7 +56,6 @@
     require_torch_gpu,
     require_transformers_version_greater,
     skip_mps,
-    slow,
     torch_device,
 )

@@ -2165,22 +2164,6 @@ def test_StableDiffusionMixin_component(self):
             )
         )

-    @require_torch_gpu
-    @slow
-    def test_torch_compile_recompilation_and_graph_break(self):
-        torch._dynamo.reset()
-        inputs = self.get_dummy_inputs(torch_device)
-        components = self.get_dummy_components()
-
-        pipe = self.pipeline_class(**components).to(torch_device)
-        if getattr(pipe, "unet", None) is not None:
-            pipe.unet = torch.compile(pipe.unet, fullgraph=True)
-        else:
-            pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
-
-        with torch._dynamo.config.patch(error_on_recompile=True):
-            _ = pipe(**inputs)
-
     @require_hf_hub_version_greater("0.26.5")
     @require_transformers_version_greater("4.47.1")
     def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
