Skip to content

Commit de30cba

Browse files
committed
Add tests guarding torch.compile behavior: no recompilations and no graph breaks when compiling a pipeline's denoiser (UNet or transformer) with fullgraph=True.
1 parent 1001425 commit de30cba

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch.nn as nn
1414
from huggingface_hub import ModelCard, delete_repo
1515
from huggingface_hub.utils import is_jinja_available
16+
from torch._dynamo.utils import counters
1617
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
1718

1819
import diffusers
@@ -45,13 +46,15 @@
4546
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
4647
from diffusers.utils.testing_utils import (
4748
CaptureLogger,
49+
backend_empty_cache,
4850
require_accelerate_version_greater,
4951
require_accelerator,
5052
require_hf_hub_version_greater,
5153
require_torch,
5254
require_torch_gpu,
5355
require_transformers_version_greater,
5456
skip_mps,
57+
slow,
5558
torch_device,
5659
)
5760

@@ -1113,8 +1116,9 @@ def setUp(self):
11131116
def tearDown(self):
    # Clean up the VRAM after each test in case of CUDA runtime errors.
    super().tearDown()
    # Reset dynamo's compilation state so compiled graphs / guards from one
    # test cannot trigger (or mask) recompilations in a later test.
    torch._dynamo.reset()
    gc.collect()
    # Device-agnostic cache clearing; replaces the CUDA-only
    # torch.cuda.empty_cache() so non-CUDA accelerators are covered too.
    backend_empty_cache(torch_device)
11181122

11191123
def test_save_load_local(self, expected_max_difference=5e-4):
11201124
components = self.get_dummy_components()
@@ -2153,6 +2157,41 @@ def test_StableDiffusionMixin_component(self):
21532157
)
21542158
)
21552159

2160+
@require_torch_gpu
@slow
def test_torch_compile_recompilation(self):
    """Compile the pipeline's denoiser with ``fullgraph=True`` and assert
    that a single inference call triggers no recompilation.

    ``error_on_recompile=True`` makes dynamo raise on any recompilation,
    which fails the test.
    """
    inputs = self.get_dummy_inputs()
    components = self.get_dummy_components()

    pipe = self.pipeline_class(**components).to(torch_device)
    # Compile whichever denoiser this pipeline actually has.
    # NOTE: the original condition was inverted (`is None`), which would
    # call torch.compile(None) for transformer-based pipelines and compile
    # the wrong attribute for unet-based ones.
    if getattr(pipe, "unet", None) is not None:
        pipe.unet = torch.compile(pipe.unet, fullgraph=True)
    else:
        pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)

    # Any recompilation inside this context raises, failing the test.
    with torch._dynamo.config.patch(error_on_recompile=True):
        _ = pipe(**inputs)
2174+
2175+
@require_torch_gpu
@slow
def test_torch_compile_graph_breaks(self):
    """Run inference through a ``fullgraph=True``-compiled denoiser and
    assert dynamo recorded zero graph breaks.

    Inspired by:
    https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138
    """
    # Start from a clean slate so breaks from earlier tests don't leak in.
    counters.clear()

    inputs = self.get_dummy_inputs()
    components = self.get_dummy_components()

    pipe = self.pipeline_class(**components).to(torch_device)
    # Compile whichever denoiser this pipeline actually has.
    # NOTE: the original condition was inverted (`is None`), which would
    # call torch.compile(None) for transformer-based pipelines and compile
    # the wrong attribute for unet-based ones.
    if getattr(pipe, "unet", None) is not None:
        pipe.unet = torch.compile(pipe.unet, fullgraph=True)
    else:
        pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)

    _ = pipe(**inputs)
    # Each distinct graph-break reason appears once as a counter key.
    num_graph_breaks = len(counters["graph_break"].keys())
    self.assertEqual(num_graph_breaks, 0)
2194+
21562195
@require_hf_hub_version_greater("0.26.5")
21572196
@require_transformers_version_greater("4.47.1")
21582197
def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):

0 commit comments

Comments
 (0)