
Commit 269973f

Linnan flux fix (#104)
* add flux finetuning yaml
* update
* fix flux adapter guidance scale
* add flux generation codes
* add multi-resolution data loader functional tests

Signed-off-by: linnan wang <linnanw@nvidia.com>
1 parent ce553b2 commit 269973f

File tree: 5 files changed, +468 −18 lines


dfm/src/automodel/flow_matching/adapters/flux.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -174,7 +174,9 @@ def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
         # The pipeline provides timesteps in [0, num_train_timesteps]
         timesteps = context.timesteps.to(dtype) / 1000.0
 
-        guidance = torch.full((batch_size,), 3.5, device=device, dtype=torch.float32)
+        # TODO: the guidance scale differs between pretraining and finetuning; we
+        # need to pass it as a hyperparameter. (Needs verification by Pranav.)
+        guidance = torch.full((batch_size,), self.guidance_scale, device=device, dtype=torch.float32)
 
         inputs = {
             "hidden_states": packed_latents,
```
Lines changed: 85 additions & 0 deletions
```yaml
model:
  pretrained_model_name_or_path: "black-forest-labs/FLUX.1-dev"
  mode: "finetune"
  cache_dir: null
  attention_backend: "_flash_3_hub"

pipeline_spec:
  transformer_cls: "FluxTransformer2DModel"
  subfolder: "transformer"
  load_full_pipeline: false
  enable_gradient_checkpointing: false

optim:
  learning_rate: 1e-5

optimizer:
  weight_decay: 0.01
  betas: [0.9, 0.999]

# Adjust dp_size to the total number of GPUs.
fsdp:
  dp_size: 8
  tp_size: 1
  cp_size: 1
  pp_size: 1
  activation_checkpointing: false
  cpu_offload: false

flow_matching:
  adapter_type: "flux"
  adapter_kwargs:
    # Critical: use a guidance scale of 3.5 for finetuning.
    guidance_scale: 3.5
    use_guidance_embeds: true
  timestep_sampling: "logit_normal"
  logit_mean: 0.0
  logit_std: 1.0
  flow_shift: 3.0
  mix_uniform_ratio: 0.1
  sigma_min: 0.0
  sigma_max: 1.0
  num_train_timesteps: 1000
  i2v_prob: 0.0
  use_loss_weighting: true
  log_interval: 100
  summary_log_interval: 10

step_scheduler:
  num_epochs: 5000
  local_batch_size: 1
  global_batch_size: 8
  ckpt_every_steps: 2000
  log_every: 1

data:
  dataloader:
    _target_: dfm.src.automodel.datasets.multiresolutionDataloader.build_flux_multiresolution_dataloader
    cache_dir: PATH_TO_YOUR_DATA
    train_text_encoder: false
    num_workers: 10
    # Supported resolutions include [256×256], [512×512], and [1024×1024].
    # While a 1:1 aspect ratio is currently used as a proxy for the closest
    # image size, the implementation is designed to support multiple aspect ratios.
    base_resolution: [512, 512]
    dynamic_batch_size: false
    shuffle: true
    drop_last: false

checkpoint:
  enabled: true
  checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
  model_save_format: torch_save
  save_consolidated: false
  restore_from: null

wandb:
  project: flux-finetuning
  mode: online
  name: flux_pretrain_ddp_test_run_1

dist_env:
  backend: "nccl"
  init_method: "env://"

seed: 42
```
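
Note that `global_batch_size` (8) equals `local_batch_size` (1) × `dp_size` (8), consistent with the comment that `dp_size` should match the total GPU count. The `_target_` key under `data.dataloader` follows the Hydra instantiation convention; the sketch below shows how such a spec might resolve, assuming a Hydra/OmegaConf-based entry point (the actual dfm launcher may differ, and the config filename is hypothetical):

```python
# Sketch only: assumes Hydra/OmegaConf; "flux_finetune.yaml" is a hypothetical
# name for the config above. instantiate() imports the `_target_` callable and
# invokes it with the sibling keys as keyword arguments, i.e. roughly:
#   build_flux_multiresolution_dataloader(cache_dir=..., train_text_encoder=False,
#       num_workers=10, base_resolution=[512, 512], dynamic_batch_size=False, ...)
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("flux_finetune.yaml")
# The builder returns a 2-tuple (the inference script below unpacks it as
# `dataloader, _`).
dataloader, _ = instantiate(cfg.data.dataloader)
```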
Lines changed: 277 additions & 0 deletions
```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FLUX Inference Script with Multi-Resolution Dataloader (Embedding Injection)

This script loads a FLUX transformer and runs inference by extracting
pre-computed text embeddings directly from the multiresolution dataloader.
"""

import argparse
import logging
import os
import random
from pathlib import Path

import numpy as np
import torch
from diffusers import FluxPipeline

# Import the provided dataloader builder
from dfm.src.automodel.datasets.multiresolutionDataloader import build_flux_multiresolution_dataloader


def parse_args():
    parser = argparse.ArgumentParser(description="FLUX Inference with pre-computed embeddings")

    parser.add_argument(
        "--model_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Base FLUX model ID from HuggingFace",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to checkpoint directory containing model/ subfolder or consolidated weights",
    )
    parser.add_argument(
        "--use-original",
        action="store_true",
        help="Use original FLUX model without loading custom checkpoint",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the dataset cache directory",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=5,
        help="Maximum number of images to generate",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./inference_outputs",
        help="Directory to save generated images",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=28,
        help="Number of inference steps",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="Guidance scale",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=512,
        help="Image height",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=512,
        help="Image width",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
        help="Data type for model",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=2,
        help="Number of workers for the dataloader",
    )

    return parser.parse_args()


def load_sharded_checkpoint(transformer, checkpoint_dir, device="cuda"):
    import torch.distributed as dist
    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint import load as dist_load
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp.api import ShardedStateDictConfig

    sharded_dir = os.path.join(checkpoint_dir, "model")
    if not os.path.isdir(sharded_dir):
        raise FileNotFoundError(f"Model directory not found: {sharded_dir}")

    # Initialize a single-process group if needed
    init_dist = False
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
        init_dist = True

    try:
        # Wrap the transformer in FSDP so the sharded state dict can be
        # materialized and loaded from the distributed checkpoint files.
        transformer = transformer.to(device=device, dtype=torch.bfloat16)
        fsdp_transformer = FSDP(transformer, use_orig_params=True, device_id=torch.device(device))

        FSDP.set_state_dict_type(
            fsdp_transformer,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        )

        model_state = fsdp_transformer.state_dict()
        dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir))
        fsdp_transformer.load_state_dict(model_state)
        transformer = fsdp_transformer.module
        print("[INFO] ✅ Successfully loaded sharded FSDP checkpoint")
    finally:
        if init_dist:
            dist.destroy_process_group()
    return transformer


def load_consolidated_checkpoint(transformer, checkpoint_path):
    print(f"[INFO] Loading consolidated checkpoint from {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    if "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    transformer.load_state_dict(state_dict, strict=True)
    print("[INFO] ✅ Loaded consolidated checkpoint")
    return transformer


def main():
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    args = parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    torch_dtype = dtype_map[args.dtype]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- 1. Initialize Dataloader ---
    print("=" * 80)
    print(f"Initializing Multiresolution Dataloader: {args.data_path}")

    dataloader, _ = build_flux_multiresolution_dataloader(
        cache_dir=args.data_path, batch_size=1, num_workers=args.num_workers, dynamic_batch_size=True, shuffle=False
    )
    print(f"[INFO] Dataloader ready. Batches: {len(dataloader)}")

    # --- 2. Initialize Model ---
    use_original = args.use_original or args.checkpoint is None

    print(f"\n[Pipeline] Loading FLUX pipeline from: {args.model_id}")
    pipe = FluxPipeline.from_pretrained(args.model_id, torch_dtype=torch_dtype)

    # Name used in output filenames; defaults to "original" so it is defined
    # even when no custom checkpoint is loaded.
    model_name = "original"
    if not use_original:
        checkpoint_dir = Path(args.checkpoint)
        model_name = checkpoint_dir.name
        sharded_dir = checkpoint_dir / "model"
        consolidated_path = checkpoint_dir / "consolidated_model.bin"
        ema_path = checkpoint_dir / "ema_shadow.pt"

        if ema_path.exists():
            print("[INFO] Loading EMA checkpoint...")
            pipe.transformer.load_state_dict(torch.load(ema_path, map_location="cpu"))
        elif consolidated_path.exists():
            pipe.transformer = load_consolidated_checkpoint(pipe.transformer, str(consolidated_path))
        elif sharded_dir.exists():
            pipe.transformer = load_sharded_checkpoint(pipe.transformer, str(checkpoint_dir), device=device)
        else:
            model_name = "original"

    pipe.enable_model_cpu_offload()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- 3. Inference Loop (Injecting Embeddings) ---
    print(f"\n[Inference] Generating {args.max_samples} samples using pre-computed embeddings...")
    generator = torch.Generator(device="cpu").manual_seed(args.seed)

    count = 0
    for batch_idx, batch in enumerate(dataloader):
        if count >= args.max_samples:
            break

        try:
            # Extract metadata for logging/filenames
            current_prompt = batch["metadata"]["prompts"][0]
            source_path = batch["metadata"]["image_paths"][0]

            # Extract and move embeddings to device/dtype:
            # batch['text_embeddings'] corresponds to the T5 output;
            # batch['pooled_prompt_embeds'] corresponds to the CLIP pooled output.
            prompt_embeds = batch["text_embeddings"].to(device=device, dtype=torch_dtype)
            pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=device, dtype=torch_dtype)

        except (KeyError, IndexError) as e:
            print(f"[WARN] Batch {batch_idx} missing required data. Skipping. Error: {e}")
            continue

        print(f"\n--- Sample {count + 1}/{args.max_samples} ---")
        print(f"  Source: {os.path.basename(source_path)}")
        print(f"  Prompt: {current_prompt[:80]}...")

        with torch.no_grad():
            # Pass embeddings directly to bypass the pipeline's internal encoders
            output = pipe(
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                height=args.height,
                width=args.width,
                generator=generator,
            )

        # Save output under a filesystem-safe name derived from the prompt
        image = output.images[0]
        safe_prompt = (
            "".join(c if c.isalnum() or c in " _-" else "" for c in current_prompt)[:50].strip().replace(" ", "_")
        )
        output_filename = f"flux_{model_name}_sample{count:03d}_{safe_prompt}.png"
        image.save(output_dir / output_filename)
        print(f"  ✅ Saved to: {output_filename}")

        count += 1

    print("\n" + "=" * 80 + "\nInference complete!\n" + "=" * 80)


if __name__ == "__main__":
    main()
```
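
A typical invocation of the script above might look like the following; the filename `flux_inference.py` is an assumption (the file's path is not shown in this view), and every flag corresponds to an option defined in `parse_args()`:

```bash
# Hypothetical script name; paths are placeholders.
python flux_inference.py \
  --data_path /path/to/dataset_cache \
  --checkpoint /path/to/checkpoint_dir \
  --max_samples 5 \
  --num_inference_steps 28 \
  --guidance_scale 3.5 \
  --height 512 --width 512 \
  --dtype bfloat16 \
  --output_dir ./inference_outputs
```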
