Skip to content

Commit 5ad508f

Browse files
committed
fix conflicts.
2 parents 95c0b52 + aa5f5d4 commit 5ad508f

File tree

6 files changed

+101
-4
lines changed

6 files changed

+101
-4
lines changed

.github/workflows/nightly_tests.yml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,55 @@ jobs:
181181
pip install slack_sdk tabulate
182182
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
183183
184+
run_torch_compile_tests:
185+
name: PyTorch Compile CUDA tests
186+
187+
runs-on:
188+
group: aws-g4dn-2xlarge
189+
190+
container:
191+
image: diffusers/diffusers-pytorch-compile-cuda
192+
options: --gpus 0 --shm-size "16gb" --ipc host
193+
194+
steps:
195+
- name: Checkout diffusers
196+
uses: actions/checkout@v3
197+
with:
198+
fetch-depth: 2
199+
200+
- name: NVIDIA-SMI
201+
run: |
202+
nvidia-smi
203+
- name: Install dependencies
204+
run: |
205+
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
206+
python -m uv pip install -e [quality,test,training]
207+
- name: Environment
208+
run: |
209+
python utils/print_env.py
210+
- name: Run torch compile tests on GPU
211+
env:
212+
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
213+
RUN_COMPILE: yes
214+
run: |
215+
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
216+
- name: Failure short reports
217+
if: ${{ failure() }}
218+
run: cat reports/tests_torch_compile_cuda_failures_short.txt
219+
220+
- name: Test suite reports artifacts
221+
if: ${{ always() }}
222+
uses: actions/upload-artifact@v4
223+
with:
224+
name: torch_compile_test_reports
225+
path: reports
226+
227+
- name: Generate Report and Notify Channel
228+
if: always()
229+
run: |
230+
pip install slack_sdk tabulate
231+
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
232+
184233
run_big_gpu_torch_tests:
185234
name: Torch tests on big GPU
186235
strategy:

.github/workflows/release_tests_fast.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ jobs:
335335
- name: Environment
336336
run: |
337337
python utils/print_env.py
338-
- name: Run example tests on GPU
338+
- name: Run torch compile tests on GPU
339339
env:
340340
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
341341
RUN_COMPILE: yes

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,15 @@ def parse_args(input_args=None):
524524
default=4,
525525
help=("The dimension of the LoRA update matrices."),
526526
)
527+
parser.add_argument(
528+
"--image_interpolation_mode",
529+
type=str,
530+
default="lanczos",
531+
choices=[
532+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
533+
],
534+
help="The image interpolation method to use for resizing images.",
535+
)
527536

528537
if input_args is not None:
529538
args = parser.parse_args(input_args)
@@ -601,9 +610,13 @@ def __init__(
601610
else:
602611
self.class_data_root = None
603612

613+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
614+
if interpolation is None:
615+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
616+
604617
self.image_transforms = transforms.Compose(
605618
[
606-
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
619+
transforms.Resize(size, interpolation=interpolation),
607620
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
608621
transforms.ToTensor(),
609622
transforms.Normalize([0.5], [0.5]),

tests/models/test_modeling_common.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,37 @@ def test_push_to_hub_library_name(self):
17131713
delete_repo(self.repo_id, token=TOKEN)
17141714

17151715

1716+
class TorchCompileTesterMixin:
1717+
def setUp(self):
1718+
# clean up the VRAM before each test
1719+
super().setUp()
1720+
torch._dynamo.reset()
1721+
gc.collect()
1722+
backend_empty_cache(torch_device)
1723+
1724+
def tearDown(self):
1725+
# clean up the VRAM after each test in case of CUDA runtime errors
1726+
super().tearDown()
1727+
torch._dynamo.reset()
1728+
gc.collect()
1729+
backend_empty_cache(torch_device)
1730+
1731+
@require_torch_gpu
1732+
@require_torch_2
1733+
@is_torch_compile
1734+
@slow
1735+
def test_torch_compile_recompilation_and_graph_break(self):
1736+
torch._dynamo.reset()
1737+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1738+
1739+
model = self.model_class(**init_dict).to(torch_device)
1740+
model = torch.compile(model, fullgraph=True)
1741+
1742+
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
1743+
_ = model(**inputs_dict)
1744+
_ = model(**inputs_dict)
1745+
1746+
17161747
@slow
17171748
@require_torch_2
17181749
@require_torch_accelerator

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from diffusers.models.embeddings import ImageProjection
2323
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
2424

25-
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin
25+
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
2626

2727

2828
enable_full_determinism()
@@ -78,7 +78,9 @@ def create_flux_ip_adapter_state_dict(model):
7878
return ip_state_dict
7979

8080

81-
class FluxTransformerTests(ModelTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase):
81+
class FluxTransformerTests(
82+
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
83+
):
8284
model_class = FluxTransformer2DModel
8385
main_input_name = "hidden_states"
8486
# We override the items here because the transformer under consideration is small.

tests/pipelines/test_pipelines_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,12 +1111,14 @@ def callback_cfg_params(self) -> frozenset:
11111111
def setUp(self):
11121112
# clean up the VRAM before each test
11131113
super().setUp()
1114+
torch._dynamo.reset()
11141115
gc.collect()
11151116
backend_empty_cache(torch_device)
11161117

11171118
def tearDown(self):
11181119
# clean up the VRAM after each test in case of CUDA runtime errors
11191120
super().tearDown()
1121+
torch._dynamo.reset()
11201122
gc.collect()
11211123
backend_empty_cache(torch_device)
11221124

0 commit comments

Comments
 (0)