
Commit e6f35ce

update files for flux finetuning and generation. (#163)

* update files for flux finetuning and generation.
* update readme.

Co-authored-by: Juan Acevedo <[email protected]>

Parent: b951454

6 files changed: +36 additions, −6 deletions

README.md

Lines changed: 21 additions & 0 deletions

````diff
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2025/04/17`**: Flux Finetuning.
 - **`2025/02/12`**: Flux LoRA for inference.
 - **`2025/02/08`**: Flux schnell & dev inference.
 - **`2024/12/12`**: Load multiple LoRAs for inference.
@@ -76,6 +77,26 @@ For your first time running Maxdiffusion, we provide specific [instructions](doc
 
 After installation completes, run the training script.
 
+- **Flux**
+
+  Expected results on 1024 x 1024 images with flash attention and bfloat16:
+
+  | Model | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
+  | --- | --- | --- | --- | --- | --- |
+  | Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
+
+  Flux finetuning has only been tested on TPU v5p.
+
+  ```bash
+  python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" save_final_checkpoint=True jax_cache_dir="/tmp/jax_cache"
+  ```
+
+  To generate images with a finetuned checkpoint, run:
+
+  ```bash
+  python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" jax_cache_dir="/tmp/jax_cache"
+  ```
+
 - **Stable Diffusion XL**
 
   ```bash
````

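Under the DDP sharding strategy shown in the table, each device processes its own per-device batch, so the global batch size is just the product of the two. A minimal sketch (assuming, as the table's numbers suggest, that a v5p-8 slice exposes 4 JAX devices — that device count is not stated in this commit):

```python
def global_batch_size(per_device_batch: int, num_devices: int) -> int:
    # DDP: every device holds a full model replica and processes its
    # own per-device batch, so the global batch is the simple product.
    return per_device_batch * num_devices

# Table row: per-device batch 1 on a 4-device v5p-8 slice -> global batch 4.
flux_dev_global = global_batch_size(1, 4)
```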
src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -143,7 +143,8 @@ def load_params_from_path(
 
   ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
   ckpt_path = epath.Path(ckpt_path)
-  ckpt_path = os.path.abspath(ckpt_path)
+  if not ckpt_path.as_uri().startswith("gs://"):
+    ckpt_path = os.path.abspath(ckpt_path)
 
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
```
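The new guard matters because `os.path.abspath` mangles GCS URIs: it treats `gs://bucket/...` as a relative path, prepends the working directory, and collapses the double slash. A minimal stdlib sketch of the same logic (using a plain string prefix check instead of `epath.Path.as_uri()`, which is a simplifying assumption here):

```python
import os

def normalize_ckpt_path(ckpt_path: str) -> str:
    # Leave GCS URIs untouched: os.path.abspath would treat "gs://..."
    # as a relative path, prepend the cwd, and collapse the "//".
    if ckpt_path.startswith("gs://"):
        return ckpt_path
    return os.path.abspath(ckpt_path)
```

Without the guard, a remote checkpoint path would come out as something like `/current/dir/gs:/bucket/ckpt`, which Orbax can no longer resolve.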

src/maxdiffusion/configs/README.md

Lines changed: 9 additions & 1 deletion

```diff
@@ -12,4 +12,12 @@ base_2_base.yml - used for training and inference using [stable-diffusion-2-base
 
 ## Stable Diffusion XL & SDXL Lightning
 
-base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+
+base_xl_lightning.yml - used to run inference using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
+
+## Flux
+
+base_flux_dev.yml - used for training and inference using [Flux Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+
+base_flux_schnell.yml - used for training and inference using [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
```

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -177,7 +177,7 @@ hf_train_files: ''
 hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
-resolution: 512
+resolution: 1024
 center_crop: False
 random_flip: False
 # If cache_latents_text_encoder_outputs is True
```
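Raising the default resolution from 512 to 1024 roughly quadruples the transformer's image-token count, since Flux encodes images through an 8x-downsampling VAE and then packs 2x2 latent patches into tokens (both factors are assumptions taken from the standard Flux architecture, not stated in this diff):

```python
def flux_image_tokens(resolution: int, vae_factor: int = 8, patch: int = 2) -> int:
    # image side -> latent side (VAE downsample) -> patch-grid side,
    # then square it for the total number of image tokens.
    side = resolution // vae_factor // patch
    return side * side
```

At 1024px that is 4096 image tokens versus 1024 at 512px, which is why the benchmark table above quotes step times specifically for 1024 x 1024 training.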

src/maxdiffusion/generate_flux_pipeline.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -96,13 +96,13 @@ def run(config):
 
   t0 = time.perf_counter()
   with ExitStack():
-    imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
+    imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
   t1 = time.perf_counter()
   max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
 
   t0 = time.perf_counter()
   with ExitStack():
-    imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
+    imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
   imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
   t1 = time.perf_counter()
   max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
```
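Besides replacing the hard-coded `timesteps=50` with `config.num_inference_steps`, this hunk shows why the script runs the pipeline twice: the first call pays the JIT compilation cost, the second measures steady-state inference. A pure-Python sketch of that timing pattern (the `sleep` standing in for XLA compilation is purely illustrative):

```python
import time
from functools import lru_cache

@lru_cache(maxsize=None)
def _compile(batch_size):
    # Stand-in for XLA tracing/compilation: runs once per input shape,
    # then the compiled function is reused from the cache.
    time.sleep(0.05)
    return lambda xs: [x * 2 for x in xs]

def pipeline(batch):
    return _compile(len(batch))(batch)

t0 = time.perf_counter()
pipeline([1.0, 2.0, 3.0])         # first call: compile + run
compile_time = time.perf_counter() - t0

t0 = time.perf_counter()
imgs = pipeline([1.0, 2.0, 3.0])  # cached: steady-state time only
infer_time = time.perf_counter() - t0
```

The `.block_until_ready()` calls in the real script serve the same purpose as timing around a synchronous function here: JAX dispatch is asynchronous, so without it the timers would only measure dispatch, not execution.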

src/maxdiffusion/pipelines/flux/flux_pipeline.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -102,7 +102,7 @@ def vae_decode(self, latents, vae, state, config):
     return img
 
   def vae_encode(self, latents, vae, state):
-    img = vae.apply({"params": state["params"]}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
+    img = vae.apply({"params": state.params}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
     img = vae.config.scaling_factor * (img - vae.config.shift_factor)
     return img
```
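The fix swaps dict-style subscripting for attribute access: the `state` passed to `vae_encode` is a Flax `TrainState`-style dataclass, which exposes `params` as a field and is not subscriptable. A simplified stand-in illustrating the difference (this `TrainState` is a sketch, not Flax's actual class):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainState:
    # Simplified stand-in for flax.training.train_state.TrainState:
    # parameters live as an attribute, not a mapping key.
    params: Any = field(default_factory=dict)

state = TrainState(params={"kernel": [1.0, 2.0]})
params = state.params              # correct: attribute access
subscript_fails = False
try:
    state["params"]                # dataclasses are not subscriptable
except TypeError:
    subscript_fails = True
```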
