@@ -13,6 +13,7 @@
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
+from torch._dynamo.utils import counters
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 import diffusers
@@ -45,13 +46,15 @@
 from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     require_accelerate_version_greater,
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
     require_torch_gpu,
     require_transformers_version_greater,
     skip_mps,
+    slow,
     torch_device,
 )

@@ -1113,8 +1116,9 @@ def setUp(self):
     def tearDown(self):
         # clean up the VRAM after each test in case of CUDA runtime errors
         super().tearDown()
+        torch._dynamo.reset()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_save_load_local(self, expected_max_difference=5e-4):
         components = self.get_dummy_components()
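
Two things change in the teardown: torch._dynamo.reset() clears Dynamo's compilation caches so one test's compiled graphs cannot leak into the next, and the CUDA-only torch.cuda.empty_cache() becomes the device-agnostic backend_empty_cache(torch_device). As a rough sketch of the dispatch idea behind such a helper (hypothetical code, not the actual diffusers implementation):

import torch

def empty_cache_for(device: str) -> None:
    # Dispatch on the device string and call the matching allocator cleanup.
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu"):
        torch.xpu.empty_cache()
    elif device.startswith("mps"):
        torch.mps.empty_cache()
    # plain CPU keeps no device-side cache, so there is nothing to clear
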
@@ -2153,6 +2157,41 @@ def test_StableDiffusionMixin_component(self):
             )
         )

+    @require_torch_gpu
+    @slow
+    def test_torch_compile_recompilation(self):
+        inputs = self.get_dummy_inputs(torch_device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components).to(torch_device)
+        if getattr(pipe, "unet", None) is not None:
+            pipe.unet = torch.compile(pipe.unet, fullgraph=True)
+        else:
+            pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            _ = pipe(**inputs)
+
+    @require_torch_gpu
+    @slow
+    def test_torch_compile_graph_breaks(self):
+        # Inspired by:
+        # https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138
+        counters.clear()
+
+        inputs = self.get_dummy_inputs(torch_device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components).to(torch_device)
+        if getattr(pipe, "unet", None) is not None:
+            pipe.unet = torch.compile(pipe.unet, fullgraph=True)
+        else:
+            pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
+
+        _ = pipe(**inputs)
+        num_graph_breaks = len(counters["graph_break"].keys())
+        self.assertEqual(num_graph_breaks, 0)
+
     @require_hf_hub_version_greater("0.26.5")
     @require_transformers_version_greater("4.47.1")
     def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
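
The first new test leans on torch._dynamo.config.patch(error_on_recompile=True): inside that context, any recompilation of an already-compiled module raises instead of silently compiling a new graph, so a clean pass proves the pipeline compiles exactly once. A minimal standalone sketch of the mechanism (toy function, not part of the PR):

import torch

@torch.compile(fullgraph=True)
def f(x):
    return x.sin() + 1

f(torch.randn(4))  # first call: compiles a graph specialized on this input

with torch._dynamo.config.patch(error_on_recompile=True):
    f(torch.randn(4))   # same shape and dtype: the cached graph is reused
    # f(torch.randn(8))  # a new shape would trigger a recompile and raise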
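
The second test reads torch._dynamo.utils.counters, a defaultdict of Counters that Dynamo fills while tracing; every fallback to eager execution is recorded under the "graph_break" key, so an empty counter after a full pipeline call means the model traced into a single graph. A small illustration of the counter (again a toy, not from the PR):

import torch
from torch._dynamo.utils import counters

def g(x):
    y = x * 2
    print("logging...")  # unsupported call -> Dynamo inserts a graph break
    return y + 1

counters.clear()
torch.compile(g)(torch.randn(4))  # default fullgraph=False tolerates the break
assert len(counters["graph_break"]) > 0  # the break is recorded by reason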