Skip to content

Commit 10b7f27

Browse files
authored
Merge branch 'main' into sana
2 parents 8b00756 + 5374821 commit 10b7f27

19 files changed: +86 additions, −534 deletions

.github/workflows/benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
runs-on:
2424
group: aws-g6-4xlarge-plus
2525
container:
26-
image: diffusers/diffusers-pytorch-compile-cuda
26+
image: diffusers/diffusers-pytorch-cuda
2727
options: --shm-size "16gb" --ipc host --gpus 0
2828
steps:
2929
- name: Checkout diffusers

.github/workflows/build_docker_images.yml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,16 @@ jobs:
3838
token: ${{ secrets.GITHUB_TOKEN }}
3939

4040
- name: Build Changed Docker Images
41+
env:
42+
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
4143
run: |
42-
CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
43-
for FILE in $CHANGED_FILES; do
44+
echo "$CHANGED_FILES"
45+
for FILE in $CHANGED_FILES; do
46+
# skip anything that isn't still on disk
47+
if [[ ! -f "$FILE" ]]; then
48+
echo "Skipping removed file $FILE"
49+
continue
50+
fi
4451
if [[ "$FILE" == docker/*Dockerfile ]]; then
4552
DOCKER_PATH="${FILE%/Dockerfile}"
4653
DOCKER_TAG=$(basename "$DOCKER_PATH")
@@ -65,7 +72,7 @@ jobs:
6572
image-name:
6673
- diffusers-pytorch-cpu
6774
- diffusers-pytorch-cuda
68-
- diffusers-pytorch-compile-cuda
75+
- diffusers-pytorch-cuda
6976
- diffusers-pytorch-xformers-cuda
7077
- diffusers-pytorch-minimum-cuda
7178
- diffusers-flax-cpu

.github/workflows/nightly_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ jobs:
188188
group: aws-g4dn-2xlarge
189189

190190
container:
191-
image: diffusers/diffusers-pytorch-compile-cuda
191+
image: diffusers/diffusers-pytorch-cuda
192192
options: --gpus 0 --shm-size "16gb" --ipc host
193193

194194
steps:

.github/workflows/push_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ jobs:
262262
group: aws-g4dn-2xlarge
263263

264264
container:
265-
image: diffusers/diffusers-pytorch-compile-cuda
265+
image: diffusers/diffusers-pytorch-cuda
266266
options: --gpus 0 --shm-size "16gb" --ipc host
267267

268268
steps:

.github/workflows/release_tests_fast.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ jobs:
316316
group: aws-g4dn-2xlarge
317317

318318
container:
319-
image: diffusers/diffusers-pytorch-compile-cuda
319+
image: diffusers/diffusers-pytorch-cuda
320320
options: --gpus 0 --shm-size "16gb" --ipc host
321321

322322
steps:

docker/diffusers-pytorch-compile-cuda/Dockerfile

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
146146
components[component].append(component_filename)
147147

148148
# If there are no component folders check the main directory for safetensors files
149+
filtered_filenames = set()
149150
if not components:
150151
if variant is not None:
151152
filtered_filenames = filter_with_regex(filenames, variant_file_re)
152-
else:
153+
154+
# If no variant filenames exist check if non-variant files are available
155+
if not filtered_filenames:
153156
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
154157
return any(".safetensors" in filename for filename in filtered_filenames)
155158

156159
# iterate over all files of a component
157160
# check if safetensor files exist for that component
158-
# if variant is provided check if the variant of the safetensors exists
159161
for component, component_filenames in components.items():
160162
matches = []
163+
filtered_component_filenames = set()
164+
# if variant is provided check if the variant of the safetensors exists
161165
if variant is not None:
162166
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
163-
else:
167+
168+
# if variant safetensor files do not exist check for non-variants
169+
if not filtered_component_filenames:
164170
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
165171
for component_filename in filtered_component_filenames:
166172
filename, extension = os.path.splitext(component_filename)

tests/models/test_modeling_common.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,14 +1748,14 @@ class TorchCompileTesterMixin:
17481748
def setUp(self):
17491749
# clean up the VRAM before each test
17501750
super().setUp()
1751-
torch._dynamo.reset()
1751+
torch.compiler.reset()
17521752
gc.collect()
17531753
backend_empty_cache(torch_device)
17541754

17551755
def tearDown(self):
17561756
# clean up the VRAM after each test in case of CUDA runtime errors
17571757
super().tearDown()
1758-
torch._dynamo.reset()
1758+
torch.compiler.reset()
17591759
gc.collect()
17601760
backend_empty_cache(torch_device)
17611761

@@ -1764,13 +1764,17 @@ def tearDown(self):
17641764
@is_torch_compile
17651765
@slow
17661766
def test_torch_compile_recompilation_and_graph_break(self):
1767-
torch._dynamo.reset()
1767+
torch.compiler.reset()
17681768
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17691769

17701770
model = self.model_class(**init_dict).to(torch_device)
17711771
model = torch.compile(model, fullgraph=True)
17721772

1773-
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
1773+
with (
1774+
torch._inductor.utils.fresh_inductor_cache(),
1775+
torch._dynamo.config.patch(error_on_recompile=True),
1776+
torch.no_grad(),
1777+
):
17741778
_ = model(**inputs_dict)
17751779
_ = model(**inputs_dict)
17761780

@@ -1798,7 +1802,7 @@ def tearDown(self):
17981802
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
17991803
# there will be recompilation errors, as torch caches the model when run in the same process.
18001804
super().tearDown()
1801-
torch._dynamo.reset()
1805+
torch.compiler.reset()
18021806
gc.collect()
18031807
backend_empty_cache(torch_device)
18041808

@@ -1915,7 +1919,7 @@ def test_hotswapping_model(self, rank0, rank1):
19151919
def test_hotswapping_compiled_model_linear(self, rank0, rank1):
19161920
# It's important to add this context to raise an error on recompilation
19171921
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
1918-
with torch._dynamo.config.patch(error_on_recompile=True):
1922+
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
19191923
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
19201924

19211925
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
@@ -1925,7 +1929,7 @@ def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
19251929

19261930
# It's important to add this context to raise an error on recompilation
19271931
target_modules = ["conv", "conv1", "conv2"]
1928-
with torch._dynamo.config.patch(error_on_recompile=True):
1932+
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
19291933
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
19301934

19311935
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
@@ -1935,7 +1939,7 @@ def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
19351939

19361940
# It's important to add this context to raise an error on recompilation
19371941
target_modules = ["to_q", "conv"]
1938-
with torch._dynamo.config.patch(error_on_recompile=True):
1942+
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
19391943
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
19401944

19411945
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa

tests/models/transformers/test_models_transformer_hunyuan_video.py

Lines changed: 7 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,16 @@
1919
from diffusers import HunyuanVideoTransformer3DModel
2020
from diffusers.utils.testing_utils import (
2121
enable_full_determinism,
22-
is_torch_compile,
23-
require_torch_2,
24-
require_torch_gpu,
25-
slow,
2622
torch_device,
2723
)
2824

29-
from ..test_modeling_common import ModelTesterMixin
25+
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
3026

3127

3228
enable_full_determinism()
3329

3430

35-
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
31+
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
3632
model_class = HunyuanVideoTransformer3DModel
3733
main_input_name = "hidden_states"
3834
uses_custom_attn_processor = True
@@ -96,23 +92,8 @@ def test_gradient_checkpointing_is_applied(self):
9692
expected_set = {"HunyuanVideoTransformer3DModel"}
9793
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
9894

99-
@require_torch_gpu
100-
@require_torch_2
101-
@is_torch_compile
102-
@slow
103-
def test_torch_compile_recompilation_and_graph_break(self):
104-
torch._dynamo.reset()
105-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
10695

107-
model = self.model_class(**init_dict).to(torch_device)
108-
model = torch.compile(model, fullgraph=True)
109-
110-
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
111-
_ = model(**inputs_dict)
112-
_ = model(**inputs_dict)
113-
114-
115-
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
96+
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
11697
model_class = HunyuanVideoTransformer3DModel
11798
main_input_name = "hidden_states"
11899
uses_custom_attn_processor = True
@@ -179,23 +160,8 @@ def test_gradient_checkpointing_is_applied(self):
179160
expected_set = {"HunyuanVideoTransformer3DModel"}
180161
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
181162

182-
@require_torch_gpu
183-
@require_torch_2
184-
@is_torch_compile
185-
@slow
186-
def test_torch_compile_recompilation_and_graph_break(self):
187-
torch._dynamo.reset()
188-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
189-
190-
model = self.model_class(**init_dict).to(torch_device)
191-
model = torch.compile(model, fullgraph=True)
192-
193-
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
194-
_ = model(**inputs_dict)
195-
_ = model(**inputs_dict)
196-
197163

198-
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
164+
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
199165
model_class = HunyuanVideoTransformer3DModel
200166
main_input_name = "hidden_states"
201167
uses_custom_attn_processor = True
@@ -260,23 +226,10 @@ def test_gradient_checkpointing_is_applied(self):
260226
expected_set = {"HunyuanVideoTransformer3DModel"}
261227
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
262228

263-
@require_torch_gpu
264-
@require_torch_2
265-
@is_torch_compile
266-
@slow
267-
def test_torch_compile_recompilation_and_graph_break(self):
268-
torch._dynamo.reset()
269-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
270229

271-
model = self.model_class(**init_dict).to(torch_device)
272-
model = torch.compile(model, fullgraph=True)
273-
274-
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
275-
_ = model(**inputs_dict)
276-
_ = model(**inputs_dict)
277-
278-
279-
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
230+
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
231+
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
232+
):
280233
model_class = HunyuanVideoTransformer3DModel
281234
main_input_name = "hidden_states"
282235
uses_custom_attn_processor = True
@@ -342,18 +295,3 @@ def test_output(self):
342295
def test_gradient_checkpointing_is_applied(self):
343296
expected_set = {"HunyuanVideoTransformer3DModel"}
344297
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
345-
346-
@require_torch_gpu
347-
@require_torch_2
348-
@is_torch_compile
349-
@slow
350-
def test_torch_compile_recompilation_and_graph_break(self):
351-
torch._dynamo.reset()
352-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
353-
354-
model = self.model_class(**init_dict).to(torch_device)
355-
model = torch.compile(model, fullgraph=True)
356-
357-
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
358-
_ = model(**inputs_dict)
359-
_ = model(**inputs_dict)

tests/models/transformers/test_models_transformer_wan.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,16 @@
1919
from diffusers import WanTransformer3DModel
2020
from diffusers.utils.testing_utils import (
2121
enable_full_determinism,
22-
is_torch_compile,
23-
require_torch_2,
24-
require_torch_gpu,
25-
slow,
2622
torch_device,
2723
)
2824

29-
from ..test_modeling_common import ModelTesterMixin
25+
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
3026

3127

3228
enable_full_determinism()
3329

3430

35-
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
31+
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
3632
model_class = WanTransformer3DModel
3733
main_input_name = "hidden_states"
3834
uses_custom_attn_processor = True
@@ -86,18 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
8682
def test_gradient_checkpointing_is_applied(self):
8783
expected_set = {"WanTransformer3DModel"}
8884
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
89-
90-
@require_torch_gpu
91-
@require_torch_2
92-
@is_torch_compile
93-
@slow
94-
def test_torch_compile_recompilation_and_graph_break(self):
95-
torch._dynamo.reset()
96-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
97-
98-
model = self.model_class(**init_dict).to(torch_device)
99-
model = torch.compile(model, fullgraph=True)
100-
101-
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
102-
_ = model(**inputs_dict)
103-
_ = model(**inputs_dict)

0 commit comments

Comments (0)