
Commit 99fc124

feat(flux): add scheduler selection for Flux models
Add support for alternative diffusers Flow Matching schedulers:
- Euler (default, 1st order)
- Heun (2nd order, better quality, 2x slower)
- LCM (optimized for few steps)

Backend:
- Add schedulers.py with scheduler type definitions and class mapping
- Modify denoise.py to accept optional scheduler parameter
- Add scheduler InputField to flux_denoise invocation (v4.2.0)

Frontend:
- Add fluxScheduler to Redux state and paramsSlice
- Create ParamFluxScheduler component for Linear UI
- Add scheduler to buildFLUXGraph for generation
1 parent 65efc3d commit 99fc124
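A quick illustration of the trade-off named in the commit message: Heun is a second-order method, so it evaluates the transformer roughly twice per user-facing step. This is a minimal sketch, not part of the commit, assuming a diffusers release that ships both Flow Matching schedulers:

```python
from diffusers import FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler

for cls in (FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler):
    scheduler = cls(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=8)
    # Heun interleaves corrector evaluations, so its internal timestep list
    # (one model call each) is roughly twice as long as Euler's for the same
    # user-facing step count -- the "2x slower" noted in the commit message.
    print(cls.__name__, len(scheduler.timesteps))
```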

11 files changed: +316 −6 lines changed


invokeai/app/invocations/flux_denoise.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -47,6 +47,7 @@
     pack,
     unpack,
 )
+from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
 from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
 from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -63,7 +64,7 @@
     title="FLUX Denoise",
     tags=["image", "flux"],
     category="image",
-    version="4.1.0",
+    version="4.2.0",
 )
 class FluxDenoiseInvocation(BaseInvocation):
     """Run denoising process with a FLUX transformer model."""
@@ -132,6 +133,12 @@ class FluxDenoiseInvocation(BaseInvocation):
     num_steps: int = InputField(
         default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
     )
+    scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
+        default="euler",
+        description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
+        "'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
+        ui_choice_labels=FLUX_SCHEDULER_LABELS,
+    )
     guidance: float = InputField(
         default=4.0,
         description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
@@ -242,6 +249,12 @@ def _run_diffusion(
             shift=not is_schnell,
         )
 
+        # Instantiate the selected diffusers scheduler. Note that "euler" is in the
+        # map too; only an unknown name leaves this as None (built-in Euler loop).
+        scheduler = None
+        if self.scheduler in FLUX_SCHEDULER_MAP:
+            scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
+            scheduler = scheduler_class(num_train_timesteps=1000)
+
         # Clip the timesteps schedule based on denoising_start and denoising_end.
         timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
 
@@ -426,6 +439,7 @@ def _run_diffusion(
             img_cond=img_cond,
             img_cond_seq=img_cond_seq,
             img_cond_seq_ids=img_cond_seq_ids,
+            scheduler=scheduler,
         )
 
         x = unpack(x.float(), self.height, self.width)
```
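The `num_train_timesteps=1000` passed at construction matters in `denoise.py` below, where scheduler timesteps on the [0, 1000] scale are divided back down to the [0, 1] sigma range the FLUX transformer expects. A standalone sketch of that round trip (diffusers only; not code from this commit):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=4)

# diffusers Flow Matching schedulers store timesteps as sigma * num_train_timesteps,
# so dividing by the config value recovers the normalized sigma in [0, 1].
for t in scheduler.timesteps:
    print(f"timestep={t.item():8.3f} -> sigma={t.item() / scheduler.config.num_train_timesteps:.3f}")
```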

invokeai/backend/flux/denoise.py

Lines changed: 182 additions & 3 deletions
```diff
@@ -1,7 +1,9 @@
+import inspect
 import math
 from typing import Callable
 
 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from tqdm import tqdm
 
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,15 +37,41 @@ def denoise(
     # extra img tokens (sequence-wise) - for Kontext conditioning
     img_cond_seq: torch.Tensor | None = None,
     img_cond_seq_ids: torch.Tensor | None = None,
+    # Optional scheduler for alternative sampling methods
+    scheduler: SchedulerMixin | None = None,
 ):
-    # step 0 is the initial state
-    total_steps = len(timesteps) - 1
+    # Determine if we're using a diffusers scheduler or the built-in Euler method
+    use_scheduler = scheduler is not None
+
+    if use_scheduler:
+        # Initialize scheduler with timesteps.
+        # The timesteps list contains values in [0, 1] range (sigmas).
+        # Some schedulers (like Euler) support custom sigmas, others (like Heun) don't.
+        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+        if "sigmas" in set_timesteps_sig.parameters:
+            # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
+            scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+        else:
+            # Scheduler doesn't support custom sigmas - use num_inference_steps.
+            # The schedule will be computed by the scheduler itself.
+            num_inference_steps = len(timesteps) - 1
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+        # For schedulers like Heun, the number of actual steps may differ
+        # (Heun doubles timesteps internally)
+        num_scheduler_steps = len(scheduler.timesteps)
+        # For user-facing step count, use the original number of denoising steps
+        total_steps = len(timesteps) - 1
+    else:
+        total_steps = len(timesteps) - 1
+        num_scheduler_steps = total_steps
 
     step_callback(
         PipelineIntermediateState(
             step=0,
             order=1,
             total_steps=total_steps,
-            timestep=int(timesteps[0]),
+            timestep=int(timesteps[0] * 1000) if use_scheduler else int(timesteps[0]),
             latents=img,
         ),
     )
@@ -53,6 +81,157 @@ def denoise(
     # Store original sequence length for slicing predictions
     original_seq_len = img.shape[1]
 
+    # Track the actual step for user-facing progress (accounts for Heun's double steps)
+    user_step = 0
+
+    if use_scheduler:
+        # Use diffusers scheduler for stepping
+        for step_index in tqdm(range(num_scheduler_steps)):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # For Heun scheduler, track if we're in first or second order step
+            is_heun = hasattr(scheduler, "state_in_first_order")
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run ControlNet models
+            controlnet_residuals: list[ControlNetFluxOutput] = []
+            for controlnet_extension in controlnet_extensions:
+                controlnet_residuals.append(
+                    controlnet_extension.run_controlnet(
+                        timestep_index=user_step,
+                        total_num_timesteps=total_steps,
+                        img=img,
+                        img_ids=img_ids,
+                        txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                        txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                        y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                        timesteps=t_vec,
+                        guidance=guidance_vec,
+                    )
+                )
+
+            merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+            # Prepare input for model
+            img_input = img
+            img_input_ids = img_ids
+
+            if img_cond is not None:
+                img_input = torch.cat((img_input, img_cond), dim=-1)
+
+            if img_cond_seq is not None:
+                assert img_cond_seq_ids is not None
+                img_input = torch.cat((img_input, img_cond_seq), dim=1)
+                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+            pred = model(
+                img=img_input,
+                img_ids=img_input_ids,
+                txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                timestep_index=user_step,
+                total_num_timesteps=total_steps,
+                controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+                controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+                ip_adapter_extensions=pos_ip_adapter_extensions,
+                regional_prompting_extension=pos_regional_prompting_extension,
+            )
+
+            if img_cond_seq is not None:
+                pred = pred[:, :original_seq_len]
+
+            # Get CFG scale for current user step
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_regional_prompting_extension is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_img_input = img
+                neg_img_input_ids = img_ids
+
+                if img_cond is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+                if img_cond_seq is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+                    neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+                neg_pred = model(
+                    img=neg_img_input,
+                    img_ids=neg_img_input_ids,
+                    txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                    timestep_index=user_step,
+                    total_num_timesteps=total_steps,
+                    controlnet_double_block_residuals=None,
+                    controlnet_single_block_residuals=None,
+                    ip_adapter_extensions=neg_ip_adapter_extensions,
+                    regional_prompting_extension=neg_regional_prompting_extension,
+                )
+
+                if img_cond_seq is not None:
+                    neg_pred = neg_pred[:, :original_seq_len]
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+            # For Heun, only increment user step after second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    # Second order step completed
+                    user_step += 1
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=2,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+            else:
+                # For Euler and other first-order schedulers
+                user_step += 1
+                preview_img = img - t_curr * pred
+                if inpaint_extension is not None:
+                    preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                step_callback(
+                    PipelineIntermediateState(
+                        step=user_step,
+                        order=1,
+                        total_steps=total_steps,
+                        timestep=int(t_curr * 1000),
+                        latents=preview_img,
+                    ),
+                )
+
+        return img
+
+    # Original Euler implementation (when scheduler is None)
     for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
 
```
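Two details of the loop above are easy to miss. First, `set_timesteps` does not have a uniform signature across diffusers schedulers, which is why the code probes it with `inspect` rather than calling it blindly. A minimal standalone sketch of that pattern (assumes only torch and diffusers, not the InvokeAI context):

```python
import inspect

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler

def configure_scheduler(scheduler, sigmas: list[float], device: torch.device) -> None:
    # Pass the custom (time-shifted) sigma schedule through when the scheduler
    # accepts one; otherwise fall back to a plain step count and accept the
    # scheduler's own schedule.
    if "sigmas" in inspect.signature(scheduler.set_timesteps).parameters:
        scheduler.set_timesteps(sigmas=sigmas, device=device)
    else:
        scheduler.set_timesteps(num_inference_steps=len(sigmas) - 1, device=device)

cpu = torch.device("cpu")
configure_scheduler(FlowMatchEulerDiscreteScheduler(), [1.0, 0.75, 0.5, 0.25], cpu)
configure_scheduler(FlowMatchHeunDiscreteScheduler(), [1.0, 0.75, 0.5, 0.25], cpu)
```

Second, the preview sent to `step_callback` is the flow matching estimate of the clean image: with x_t = (1 − t)·x0 + t·ε and the model predicting the velocity v ≈ ε − x0, solving for the clean latent gives x0 ≈ x_t − t·v, which is exactly `img - t_curr * pred`.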

invokeai/backend/flux/schedulers.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -0,0 +1,40 @@
+"""Flux scheduler definitions and mapping.
+
+This module provides the scheduler types and mapping for Flux models,
+supporting multiple Flow Matching schedulers from the diffusers library.
+"""
+
+from typing import Literal, Type
+
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    FlowMatchHeunDiscreteScheduler,
+)
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
+try:
+    from diffusers import FlowMatchLCMScheduler
+
+    _HAS_LCM = True
+except ImportError:
+    _HAS_LCM = False
+
+# Scheduler name literal type for type checking
+FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+FLUX_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes
+FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
```
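A hypothetical smoke test for the mapping above (not part of the commit; it assumes only the module as defined in this diff):

```python
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP

for name, label in FLUX_SCHEDULER_LABELS.items():
    cls = FLUX_SCHEDULER_MAP.get(name)
    # "lcm" is absent from the map when the installed diffusers version
    # does not ship FlowMatchLCMScheduler (the try/except above).
    print(f"{name} ({label}): {cls.__name__ if cls else 'unavailable'}")
```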

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 5 additions & 0 deletions
```diff
@@ -69,6 +69,9 @@ const slice = createSlice({
     setScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
       state.scheduler = action.payload;
     },
+    setFluxScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
+      state.fluxScheduler = action.payload;
+    },
     setUpscaleScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
       state.upscaleScheduler = action.payload;
     },
@@ -449,6 +452,7 @@
   setCfgRescaleMultiplier,
   setGuidance,
   setScheduler,
+  setFluxScheduler,
   setUpscaleScheduler,
   setUpscaleCfgScale,
   setSeed,
@@ -588,6 +592,7 @@
   (model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
 );
 export const selectScheduler = createParamsSelector((params) => params.scheduler);
+export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler);
 export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis);
 export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis);
 export const selectSeed = createParamsSelector((params) => params.seed);
```

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 3 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@ import {
   zParameterCLIPGEmbedModel,
   zParameterCLIPLEmbedModel,
   zParameterControlLoRAModel,
+  zParameterFluxScheduler,
   zParameterGuidance,
   zParameterImageDimension,
   zParameterMaskBlurMethod,
@@ -596,6 +597,7 @@ export const zParamsState = z.object({
   optimizedDenoisingEnabled: z.boolean(),
   iterations: z.number(),
   scheduler: zParameterScheduler,
+  fluxScheduler: zParameterFluxScheduler,
   upscaleScheduler: zParameterScheduler,
   upscaleCfgScale: zParameterCFGScale,
   seed: zParameterSeed,
@@ -650,6 +652,7 @@ export const getInitialParamsState = (): ParamsState => ({
   optimizedDenoisingEnabled: true,
   iterations: 1,
   scheduler: 'dpmpp_3m_k',
+  fluxScheduler: 'euler',
   upscaleScheduler: 'kdpm_2',
   upscaleCfgScale: 2,
   seed: 0,
```

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 3 additions & 0 deletions
```diff
@@ -64,6 +64,9 @@ export const zSchedulerField = z.enum([
   'tcd',
 ]);
 export type SchedulerField = z.infer<typeof zSchedulerField>;
+
+// Flux-specific scheduler options (Flow Matching schedulers)
+export const zFluxSchedulerField = z.enum(['euler', 'heun', 'lcm']);
 // #endregion
 
 // #region Model-related schemas
```

invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts

Lines changed: 2 additions & 1 deletion
```diff
@@ -43,7 +43,7 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
   const canvas = selectCanvasSlice(state);
   const refImages = selectRefImagesSlice(state);
 
-  const { guidance: baseGuidance, steps, fluxVAE, t5EncoderModel, clipEmbedModel } = params;
+  const { guidance: baseGuidance, steps, fluxScheduler, fluxVAE, t5EncoderModel, clipEmbedModel } = params;
 
   assert(t5EncoderModel, 'No T5 Encoder model found in state');
   assert(clipEmbedModel, 'No CLIP Embed model found in state');
@@ -114,6 +114,7 @@
     id: getPrefixedId('flux_denoise'),
     guidance,
     num_steps: steps,
+    scheduler: fluxScheduler,
   });
 
   const l2i = g.addNode({
```
