Skip to content

Commit 99c7696

Browse files
authored
Merge branch 'main' into lora-tests-cleanup
2 parents 2b0a7f0 + 534848c commit 99c7696

File tree

7 files changed

+2023
-15
lines changed

7 files changed

+2023
-15
lines changed

examples/controlnet/README_flux.md

Lines changed: 430 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
accelerate>=0.16.0
2+
torchvision
3+
transformers>=4.25.1
4+
ftfy
5+
tensorboard
6+
Jinja2
7+
datasets
8+
wandb
9+
SentencePiece

examples/controlnet/test_controlnet.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,28 @@ def test_controlnet_sd3(self):
136136
run_command(self._launch_args + test_args)
137137

138138
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
139+
140+
141+
class ControlNetflux(ExamplesTestsAccelerate):
142+
def test_controlnet_flux(self):
143+
with tempfile.TemporaryDirectory() as tmpdir:
144+
test_args = f"""
145+
examples/controlnet/train_controlnet_flux.py
146+
--pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
147+
--output_dir={tmpdir}
148+
--dataset_name=hf-internal-testing/fill10
149+
--conditioning_image_column=conditioning_image
150+
--image_column=image
151+
--caption_column=text
152+
--resolution=64
153+
--train_batch_size=1
154+
--gradient_accumulation_steps=1
155+
--max_train_steps=4
156+
--checkpointing_steps=2
157+
--num_double_layers=1
158+
--num_single_layers=1
159+
""".split()
160+
161+
run_command(self._launch_args + test_args)
162+
163+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))

0 commit comments

Comments
 (0)