Skip to content

Commit 689953e

Browse files
PfannkuchensackJPPhotolstein
authored
Feature/zimage scheduler support (#8705)
* feat(flux): add scheduler selection for Flux models Add support for alternative diffusers Flow Matching schedulers: - Euler (default, 1st order) - Heun (2nd order, better quality, 2x slower) - LCM (optimized for few steps) Backend: - Add schedulers.py with scheduler type definitions and class mapping - Modify denoise.py to accept optional scheduler parameter - Add scheduler InputField to flux_denoise invocation (v4.2.0) Frontend: - Add fluxScheduler to Redux state and paramsSlice - Create ParamFluxScheduler component for Linear UI - Add scheduler to buildFLUXGraph for generation * feat(z-image): add scheduler selection for Z-Image models Add support for alternative diffusers Flow Matching schedulers for Z-Image: - Euler (default) - 1st order, optimized for Z-Image-Turbo (8 steps) - Heun (2nd order) - Better quality, 2x slower - LCM - Optimized for few-step generation Backend: - Extend schedulers.py with Z-Image scheduler types and mapping - Add scheduler InputField to z_image_denoise invocation (v1.3.0) - Refactor denoising loop to support diffusers schedulers Frontend: - Add zImageScheduler to Redux state in paramsSlice - Create ParamZImageScheduler component for Linear UI - Add scheduler to buildZImageGraph for generation * fix ruff check * fix(schedulers): prevent progress percentage overflow with LCM scheduler LCM scheduler may have more internal timesteps than user-facing steps, causing user_step to exceed total_steps. This resulted in progress percentage > 1.0, which caused a pydantic validation error. Fix: Only call step_callback when user_step <= total_steps. * Ruff format * fix(schedulers): remove initial step-0 callback for consistent step count Remove the initial step_callback at step=0 to match SD/SDXL behavior. Previously Flux/Z-Image showed N+1 steps (step 0 + N denoising steps), while SD/SDXL showed only N steps. Now all models display N steps consistently in the server log. * feat(z-image): add scheduler support with metadata recall - Handle LCM scheduler by using num_inference_steps instead of custom sigmas - Fix progress bar to show user-facing steps instead of internal scheduler steps - Pass scheduler parameter to Z-Image denoise node in graph builder - Add model-aware metadata recall for Flux and Z-Image schedulers --------- Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com> Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 61c2589 commit 689953e

File tree

15 files changed

+676
-98
lines changed

15 files changed

+676
-98
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
pack,
4848
unpack,
4949
)
50+
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
5051
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
5152
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
5253
from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -63,7 +64,7 @@
6364
title="FLUX Denoise",
6465
tags=["image", "flux"],
6566
category="image",
66-
version="4.1.0",
67+
version="4.2.0",
6768
)
6869
class FluxDenoiseInvocation(BaseInvocation):
6970
"""Run denoising process with a FLUX transformer model."""
@@ -132,6 +133,12 @@ class FluxDenoiseInvocation(BaseInvocation):
132133
num_steps: int = InputField(
133134
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
134135
)
136+
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
137+
default="euler",
138+
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
139+
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
140+
ui_choice_labels=FLUX_SCHEDULER_LABELS,
141+
)
135142
guidance: float = InputField(
136143
default=4.0,
137144
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
@@ -242,6 +249,12 @@ def _run_diffusion(
242249
shift=not is_schnell,
243250
)
244251

252+
# Create scheduler if not using default euler
253+
scheduler = None
254+
if self.scheduler in FLUX_SCHEDULER_MAP:
255+
scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
256+
scheduler = scheduler_class(num_train_timesteps=1000)
257+
245258
# Clip the timesteps schedule based on denoising_start and denoising_end.
246259
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
247260

@@ -426,6 +439,7 @@ def _run_diffusion(
426439
img_cond=img_cond,
427440
img_cond_seq=img_cond_seq,
428441
img_cond_seq_ids=img_cond_seq_ids,
442+
scheduler=scheduler,
429443
)
430444

431445
x = unpack(x.float(), self.height, self.width)

invokeai/app/invocations/z_image_denoise.py

Lines changed: 237 additions & 82 deletions
Large diffs are not rendered by default.

invokeai/backend/flux/denoise.py

Lines changed: 188 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import inspect
12
import math
23
from typing import Callable
34

45
import torch
6+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
57
from tqdm import tqdm
68

79
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,24 +37,199 @@ def denoise(
3537
# extra img tokens (sequence-wise) - for Kontext conditioning
3638
img_cond_seq: torch.Tensor | None = None,
3739
img_cond_seq_ids: torch.Tensor | None = None,
40+
# Optional scheduler for alternative sampling methods
41+
scheduler: SchedulerMixin | None = None,
3842
):
39-
# step 0 is the initial state
40-
total_steps = len(timesteps) - 1
41-
step_callback(
42-
PipelineIntermediateState(
43-
step=0,
44-
order=1,
45-
total_steps=total_steps,
46-
timestep=int(timesteps[0]),
47-
latents=img,
48-
),
49-
)
43+
# Determine if we're using a diffusers scheduler or the built-in Euler method
44+
use_scheduler = scheduler is not None
45+
46+
if use_scheduler:
47+
# Initialize scheduler with timesteps
48+
# The timesteps list contains values in [0, 1] range (sigmas)
49+
# Some schedulers (like Euler) support custom sigmas, others (like Heun) don't
50+
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
51+
if "sigmas" in set_timesteps_sig.parameters:
52+
# Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
53+
scheduler.set_timesteps(sigmas=timesteps, device=img.device)
54+
else:
55+
# Scheduler doesn't support custom sigmas - use num_inference_steps
56+
# The schedule will be computed by the scheduler itself
57+
num_inference_steps = len(timesteps) - 1
58+
scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
59+
60+
# For schedulers like Heun, the number of actual steps may differ
61+
# (Heun doubles timesteps internally)
62+
num_scheduler_steps = len(scheduler.timesteps)
63+
# For user-facing step count, use the original number of denoising steps
64+
total_steps = len(timesteps) - 1
65+
else:
66+
total_steps = len(timesteps) - 1
67+
num_scheduler_steps = total_steps
68+
5069
# guidance_vec is ignored for schnell.
5170
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
5271

5372
# Store original sequence length for slicing predictions
5473
original_seq_len = img.shape[1]
5574

75+
# Track the actual step for user-facing progress (accounts for Heun's double steps)
76+
user_step = 0
77+
78+
if use_scheduler:
79+
# Use diffusers scheduler for stepping
80+
for step_index in tqdm(range(num_scheduler_steps)):
81+
timestep = scheduler.timesteps[step_index]
82+
# Convert scheduler timestep (0-1000) to normalized (0-1) for the model
83+
t_curr = timestep.item() / scheduler.config.num_train_timesteps
84+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
85+
86+
# For Heun scheduler, track if we're in first or second order step
87+
is_heun = hasattr(scheduler, "state_in_first_order")
88+
in_first_order = scheduler.state_in_first_order if is_heun else True
89+
90+
# Run ControlNet models
91+
controlnet_residuals: list[ControlNetFluxOutput] = []
92+
for controlnet_extension in controlnet_extensions:
93+
controlnet_residuals.append(
94+
controlnet_extension.run_controlnet(
95+
timestep_index=user_step,
96+
total_num_timesteps=total_steps,
97+
img=img,
98+
img_ids=img_ids,
99+
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
100+
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
101+
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
102+
timesteps=t_vec,
103+
guidance=guidance_vec,
104+
)
105+
)
106+
107+
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
108+
109+
# Prepare input for model
110+
img_input = img
111+
img_input_ids = img_ids
112+
113+
if img_cond is not None:
114+
img_input = torch.cat((img_input, img_cond), dim=-1)
115+
116+
if img_cond_seq is not None:
117+
assert img_cond_seq_ids is not None
118+
img_input = torch.cat((img_input, img_cond_seq), dim=1)
119+
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
120+
121+
pred = model(
122+
img=img_input,
123+
img_ids=img_input_ids,
124+
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
125+
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
126+
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
127+
timesteps=t_vec,
128+
guidance=guidance_vec,
129+
timestep_index=user_step,
130+
total_num_timesteps=total_steps,
131+
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
132+
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
133+
ip_adapter_extensions=pos_ip_adapter_extensions,
134+
regional_prompting_extension=pos_regional_prompting_extension,
135+
)
136+
137+
if img_cond_seq is not None:
138+
pred = pred[:, :original_seq_len]
139+
140+
# Get CFG scale for current user step
141+
step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
142+
143+
if not math.isclose(step_cfg_scale, 1.0):
144+
if neg_regional_prompting_extension is None:
145+
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
146+
147+
neg_img_input = img
148+
neg_img_input_ids = img_ids
149+
150+
if img_cond is not None:
151+
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
152+
153+
if img_cond_seq is not None:
154+
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
155+
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
156+
157+
neg_pred = model(
158+
img=neg_img_input,
159+
img_ids=neg_img_input_ids,
160+
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
161+
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
162+
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
163+
timesteps=t_vec,
164+
guidance=guidance_vec,
165+
timestep_index=user_step,
166+
total_num_timesteps=total_steps,
167+
controlnet_double_block_residuals=None,
168+
controlnet_single_block_residuals=None,
169+
ip_adapter_extensions=neg_ip_adapter_extensions,
170+
regional_prompting_extension=neg_regional_prompting_extension,
171+
)
172+
173+
if img_cond_seq is not None:
174+
neg_pred = neg_pred[:, :original_seq_len]
175+
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
176+
177+
# Use scheduler.step() for the update
178+
step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
179+
img = step_output.prev_sample
180+
181+
# Get t_prev for inpainting (next sigma value)
182+
if step_index + 1 < len(scheduler.sigmas):
183+
t_prev = scheduler.sigmas[step_index + 1].item()
184+
else:
185+
t_prev = 0.0
186+
187+
if inpaint_extension is not None:
188+
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
189+
190+
# For Heun, only increment user step after second-order step completes
191+
if is_heun:
192+
if not in_first_order:
193+
# Second order step completed
194+
user_step += 1
195+
# Only call step_callback if we haven't exceeded total_steps
196+
if user_step <= total_steps:
197+
preview_img = img - t_curr * pred
198+
if inpaint_extension is not None:
199+
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
200+
preview_img, 0.0
201+
)
202+
step_callback(
203+
PipelineIntermediateState(
204+
step=user_step,
205+
order=2,
206+
total_steps=total_steps,
207+
timestep=int(t_curr * 1000),
208+
latents=preview_img,
209+
),
210+
)
211+
else:
212+
# For Euler, LCM and other first-order schedulers
213+
user_step += 1
214+
# Only call step_callback if we haven't exceeded total_steps
215+
# (LCM scheduler may have more internal steps than user-facing steps)
216+
if user_step <= total_steps:
217+
preview_img = img - t_curr * pred
218+
if inpaint_extension is not None:
219+
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
220+
step_callback(
221+
PipelineIntermediateState(
222+
step=user_step,
223+
order=1,
224+
total_steps=total_steps,
225+
timestep=int(t_curr * 1000),
226+
latents=preview_img,
227+
),
228+
)
229+
230+
return img
231+
232+
# Original Euler implementation (when scheduler is None)
56233
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
57234
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
58235

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Flow Matching scheduler definitions and mapping.
2+
3+
This module provides the scheduler types and mapping for Flow Matching models
4+
(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
5+
"""
6+
7+
from typing import Literal, Type
8+
9+
from diffusers import (
10+
FlowMatchEulerDiscreteScheduler,
11+
FlowMatchHeunDiscreteScheduler,
12+
)
13+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
14+
15+
# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
16+
try:
17+
from diffusers import FlowMatchLCMScheduler
18+
19+
_HAS_LCM = True
20+
except ImportError:
21+
_HAS_LCM = False
22+
23+
# Scheduler name literal type for type checking
24+
FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
25+
26+
# Human-readable labels for the UI
27+
FLUX_SCHEDULER_LABELS: dict[str, str] = {
28+
"euler": "Euler",
29+
"heun": "Heun (2nd order)",
30+
"lcm": "LCM",
31+
}
32+
33+
# Mapping from scheduler names to scheduler classes
34+
FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
35+
"euler": FlowMatchEulerDiscreteScheduler,
36+
"heun": FlowMatchHeunDiscreteScheduler,
37+
}
38+
39+
if _HAS_LCM:
40+
FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
41+
42+
43+
# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
44+
# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
45+
# can be used for experimentation.
46+
ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
47+
48+
# Human-readable labels for the UI
49+
ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
50+
"euler": "Euler",
51+
"heun": "Heun (2nd order)",
52+
"lcm": "LCM",
53+
}
54+
55+
# Mapping from scheduler names to scheduler classes (same as Flux)
56+
ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
57+
"euler": FlowMatchEulerDiscreteScheduler,
58+
"heun": FlowMatchHeunDiscreteScheduler,
59+
}
60+
61+
if _HAS_LCM:
62+
ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler

invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ const slice = createSlice({
6969
setScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
7070
state.scheduler = action.payload;
7171
},
72+
setFluxScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
73+
state.fluxScheduler = action.payload;
74+
},
75+
setZImageScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
76+
state.zImageScheduler = action.payload;
77+
},
7278
setUpscaleScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
7379
state.upscaleScheduler = action.payload;
7480
},
@@ -449,6 +455,8 @@ export const {
449455
setCfgRescaleMultiplier,
450456
setGuidance,
451457
setScheduler,
458+
setFluxScheduler,
459+
setZImageScheduler,
452460
setUpscaleScheduler,
453461
setUpscaleCfgScale,
454462
setSeed,
@@ -588,6 +596,8 @@ export const selectModelSupportsOptimizedDenoising = createSelector(
588596
(model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
589597
);
590598
export const selectScheduler = createParamsSelector((params) => params.scheduler);
599+
export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler);
600+
export const selectZImageScheduler = createParamsSelector((params) => params.zImageScheduler);
591601
export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis);
592602
export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis);
593603
export const selectSeed = createParamsSelector((params) => params.seed);

invokeai/frontend/web/src/features/controlLayers/store/types.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
zParameterCLIPGEmbedModel,
1010
zParameterCLIPLEmbedModel,
1111
zParameterControlLoRAModel,
12+
zParameterFluxScheduler,
1213
zParameterGuidance,
1314
zParameterImageDimension,
1415
zParameterMaskBlurMethod,
@@ -23,6 +24,7 @@ import {
2324
zParameterStrength,
2425
zParameterT5EncoderModel,
2526
zParameterVAEModel,
27+
zParameterZImageScheduler,
2628
} from 'features/parameters/types/parameterSchemas';
2729
import type { JsonObject } from 'type-fest';
2830
import { z } from 'zod';
@@ -596,6 +598,8 @@ export const zParamsState = z.object({
596598
optimizedDenoisingEnabled: z.boolean(),
597599
iterations: z.number(),
598600
scheduler: zParameterScheduler,
601+
fluxScheduler: zParameterFluxScheduler,
602+
zImageScheduler: zParameterZImageScheduler,
599603
upscaleScheduler: zParameterScheduler,
600604
upscaleCfgScale: zParameterCFGScale,
601605
seed: zParameterSeed,
@@ -650,6 +654,8 @@ export const getInitialParamsState = (): ParamsState => ({
650654
optimizedDenoisingEnabled: true,
651655
iterations: 1,
652656
scheduler: 'dpmpp_3m_k',
657+
fluxScheduler: 'euler',
658+
zImageScheduler: 'euler',
653659
upscaleScheduler: 'kdpm_2',
654660
upscaleCfgScale: 2,
655661
seed: 0,

0 commit comments

Comments
 (0)