+import inspect
 import math
 from typing import Callable
 
 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from tqdm import tqdm
 
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,15 +37,41 @@ def denoise(
     # extra img tokens (sequence-wise) - for Kontext conditioning
     img_cond_seq: torch.Tensor | None = None,
     img_cond_seq_ids: torch.Tensor | None = None,
+    # Optional scheduler for alternative sampling methods
+    scheduler: SchedulerMixin | None = None,
 ):
-    # step 0 is the initial state
-    total_steps = len(timesteps) - 1
+    # Determine if we're using a diffusers scheduler or the built-in Euler method
+    use_scheduler = scheduler is not None
+
+    if use_scheduler:
+        # Initialize the scheduler with timesteps
+        # The timesteps list contains values in the [0, 1] range (sigmas)
+        # Some schedulers (like Euler) support custom sigmas; others (like Heun) don't
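+        # Feature-detect via inspect rather than hard-coding scheduler classes, so newly
+        # added diffusers schedulers keep working here without changes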
+        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+        if "sigmas" in set_timesteps_sig.parameters:
+            # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
+            scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+        else:
+            # Scheduler doesn't support custom sigmas - use num_inference_steps
+            # The schedule will be computed by the scheduler itself
+            num_inference_steps = len(timesteps) - 1
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+        # For schedulers like Heun, the number of actual steps may differ
+        # (Heun doubles timesteps internally)
+        num_scheduler_steps = len(scheduler.timesteps)
+        # For user-facing step count, use the original number of denoising steps
+        total_steps = len(timesteps) - 1
+    else:
+        total_steps = len(timesteps) - 1
+        num_scheduler_steps = total_steps
+
     step_callback(
         PipelineIntermediateState(
             step=0,
             order=1,
             total_steps=total_steps,
-            timestep=int(timesteps[0]),
+            timestep=int(timesteps[0] * 1000) if use_scheduler else int(timesteps[0]),
             latents=img,
         ),
     )
@@ -53,6 +81,157 @@ def denoise(
     # Store original sequence length for slicing predictions
     original_seq_len = img.shape[1]
 
+    # Track the actual step for user-facing progress (accounts for Heun's double steps)
+    user_step = 0
+
+    if use_scheduler:
+        # Use diffusers scheduler for stepping
+        for step_index in tqdm(range(num_scheduler_steps)):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
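+            # (e.g. timestep 1000 -> t_curr == 1.0 at the first step, assuming the usual
+            # num_train_timesteps of 1000)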
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # For the Heun scheduler, track whether we're in a first- or second-order step
+            is_heun = hasattr(scheduler, "state_in_first_order")
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run ControlNet models
+            controlnet_residuals: list[ControlNetFluxOutput] = []
+            for controlnet_extension in controlnet_extensions:
+                controlnet_residuals.append(
+                    controlnet_extension.run_controlnet(
+                        timestep_index=user_step,
+                        total_num_timesteps=total_steps,
+                        img=img,
+                        img_ids=img_ids,
+                        txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                        txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                        y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                        timesteps=t_vec,
+                        guidance=guidance_vec,
+                    )
+                )
+
+            merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+            # Prepare input for model
+            img_input = img
+            img_input_ids = img_ids
+
+            if img_cond is not None:
+                img_input = torch.cat((img_input, img_cond), dim=-1)
+
+            if img_cond_seq is not None:
+                assert img_cond_seq_ids is not None
+                img_input = torch.cat((img_input, img_cond_seq), dim=1)
+                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+            pred = model(
+                img=img_input,
+                img_ids=img_input_ids,
+                txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                timestep_index=user_step,
+                total_num_timesteps=total_steps,
+                controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+                controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+                ip_adapter_extensions=pos_ip_adapter_extensions,
+                regional_prompting_extension=pos_regional_prompting_extension,
+            )
+
+            if img_cond_seq is not None:
+                pred = pred[:, :original_seq_len]
+
+            # Get CFG scale for current user step
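+            # (the index is clamped so Heun's doubled sub-steps can't run past the end of a
+            # per-step cfg_scale list)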
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_regional_prompting_extension is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_img_input = img
+                neg_img_input_ids = img_ids
+
+                if img_cond is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+                if img_cond_seq is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+                    neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+                neg_pred = model(
+                    img=neg_img_input,
+                    img_ids=neg_img_input_ids,
+                    txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                    timestep_index=user_step,
+                    total_num_timesteps=total_steps,
+                    controlnet_double_block_residuals=None,
+                    controlnet_single_block_residuals=None,
+                    ip_adapter_extensions=neg_ip_adapter_extensions,
+                    regional_prompting_extension=neg_regional_prompting_extension,
+                )
+
+                if img_cond_seq is not None:
+                    neg_pred = neg_pred[:, :original_seq_len]
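+                # Classifier-free guidance: start from the unconditional prediction and move
+                # toward the conditional one; a scale > 1 extrapolates past it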
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
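+            # (for flow-match schedulers this advances the sample along the predicted velocity;
+            # e.g. the Euler update is img + (sigma_next - sigma) * pred)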
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (the next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+            # For Heun, only increment the user-facing step after the second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    # Second-order step completed
+                    user_step += 1
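+                    # x0-style preview: with rectified-flow velocity prediction, img - t_curr * pred
+                    # approximates the fully denoised latents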
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=2,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+            else:
+                # For Euler and other first-order schedulers
+                user_step += 1
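+                # Same x0-style preview estimate as in the Heun branch above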
+                preview_img = img - t_curr * pred
+                if inpaint_extension is not None:
+                    preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                step_callback(
+                    PipelineIntermediateState(
+                        step=user_step,
+                        order=1,
+                        total_steps=total_steps,
+                        timestep=int(t_curr * 1000),
+                        latents=preview_img,
+                    ),
+                )
+
+        return img
+
+    # Original Euler implementation (when scheduler is None)
     for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
 