@@ -1,7 +1,9 @@
+import inspect
 import math
 from typing import Callable
 
 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from tqdm import tqdm
 
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,24 +37,199 @@ def denoise(
     # extra img tokens (sequence-wise) - for Kontext conditioning
     img_cond_seq: torch.Tensor | None = None,
     img_cond_seq_ids: torch.Tensor | None = None,
+    # Optional scheduler for alternative sampling methods
+    scheduler: SchedulerMixin | None = None,
 ):
-    # step 0 is the initial state
-    total_steps = len(timesteps) - 1
-    step_callback(
-        PipelineIntermediateState(
-            step=0,
-            order=1,
-            total_steps=total_steps,
-            timestep=int(timesteps[0]),
-            latents=img,
-        ),
-    )
+    # Determine whether we're using a diffusers scheduler or the built-in Euler method.
+    use_scheduler = scheduler is not None
+
+    if use_scheduler:
+        # Initialize the scheduler with the timestep schedule.
+        # The timesteps list contains values in the [0, 1] range (sigmas).
+        # Some schedulers (like Euler) support custom sigmas; others (like Heun) don't.
+        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+        if "sigmas" in set_timesteps_sig.parameters:
+            # The scheduler supports custom sigmas - use InvokeAI's time-shifted schedule.
+            scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+        else:
+            # The scheduler doesn't support custom sigmas - pass num_inference_steps and
+            # let the scheduler compute the schedule itself.
+            num_inference_steps = len(timesteps) - 1
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+        # The number of internal scheduler steps may differ from the number of denoising
+        # steps (e.g. Heun doubles its timesteps internally).
+        num_scheduler_steps = len(scheduler.timesteps)
+        # For the user-facing step count, use the original number of denoising steps.
+        total_steps = len(timesteps) - 1
+    else:
+        total_steps = len(timesteps) - 1
+        num_scheduler_steps = total_steps
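+
+    # For example, with 4 denoising steps the shifted schedule might look like
+    # [1.0, 0.78, 0.51, 0.27, 0.0] (illustrative values): a scheduler that accepts
+    # custom sigmas consumes it directly, while one that doesn't (e.g. Heun) builds
+    # its own schedule from num_inference_steps and may take more internal steps
+    # than user-facing ones.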
     # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
 
     # Store original sequence length for slicing predictions
     original_seq_len = img.shape[1]
 
+    # Track the actual step for user-facing progress (accounts for Heun's doubled steps).
+    user_step = 0
+
+    if use_scheduler:
+        # Use the diffusers scheduler for stepping.
+        for step_index in tqdm(range(num_scheduler_steps)):
+            timestep = scheduler.timesteps[step_index]
+            # Convert the scheduler timestep (0-1000 range) to the normalized (0-1) range
+            # expected by the model.
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # For the Heun scheduler, track whether we're in the first- or second-order step.
+            is_heun = hasattr(scheduler, "state_in_first_order")
+            in_first_order = scheduler.state_in_first_order if is_heun else True
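+            # Heun alternates a first-order (Euler) predictor with a second-order
+            # corrector, calling the model twice per user-facing step; the
+            # state_in_first_order flag tells us which half of the step we're in.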
+
+            # Run ControlNet models.
+            controlnet_residuals: list[ControlNetFluxOutput] = []
+            for controlnet_extension in controlnet_extensions:
+                controlnet_residuals.append(
+                    controlnet_extension.run_controlnet(
+                        timestep_index=user_step,
+                        total_num_timesteps=total_steps,
+                        img=img,
+                        img_ids=img_ids,
+                        txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                        txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                        y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                        timesteps=t_vec,
+                        guidance=guidance_vec,
+                    )
+                )
+
+            merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+            # Prepare input for model
+            img_input = img
+            img_input_ids = img_ids
+
+            if img_cond is not None:
+                img_input = torch.cat((img_input, img_cond), dim=-1)
+
+            if img_cond_seq is not None:
+                assert img_cond_seq_ids is not None
+                img_input = torch.cat((img_input, img_cond_seq), dim=1)
+                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
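+            # Note: img_cond is concatenated channel-wise (dim=-1, e.g. structural
+            # conditioning), while img_cond_seq appends extra Kontext tokens
+            # sequence-wise (dim=1), which is why predictions are sliced back to
+            # original_seq_len below.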
+
+            pred = model(
+                img=img_input,
+                img_ids=img_input_ids,
+                txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                timestep_index=user_step,
+                total_num_timesteps=total_steps,
+                controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+                controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+                ip_adapter_extensions=pos_ip_adapter_extensions,
+                regional_prompting_extension=pos_regional_prompting_extension,
+            )
+
+            # Slice off the Kontext conditioning tokens from the prediction.
+            if img_cond_seq is not None:
+                pred = pred[:, :original_seq_len]
+
+            # Get the CFG scale for the current user step (the list is clamped to its last entry).
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_regional_prompting_extension is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_img_input = img
+                neg_img_input_ids = img_ids
+
+                if img_cond is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+                if img_cond_seq is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+                    neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+                neg_pred = model(
+                    img=neg_img_input,
+                    img_ids=neg_img_input_ids,
+                    txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                    timestep_index=user_step,
+                    total_num_timesteps=total_steps,
+                    controlnet_double_block_residuals=None,
+                    controlnet_single_block_residuals=None,
+                    ip_adapter_extensions=neg_ip_adapter_extensions,
+                    regional_prompting_extension=neg_regional_prompting_extension,
+                )
+
+                if img_cond_seq is not None:
+                    neg_pred = neg_pred[:, :original_seq_len]
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
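+                # Classifier-free guidance: extrapolate from the unconditional (negative)
+                # prediction toward the conditional one. A scale of 1.0 reduces to the
+                # conditional prediction, which is why this branch is skipped in that case.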
+
+            # Use scheduler.step() for the update.
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
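+            # For flow-matching schedulers this advances the sample from the current
+            # sigma to the next, i.e. the Euler update x <- x + (sigma_next - sigma) * pred,
+            # plus any higher-order correction the scheduler applies.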
+
+            # Get t_prev for inpainting (the next sigma value).
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
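+            # The inpaint extension blends the denoised latents with the init latents
+            # (re-noised to t_prev) outside the masked region.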
+
+            # For Heun, only increment the user step after the second-order step completes;
+            # first-order schedulers (Euler, LCM, ...) complete a step on every iteration.
+            if not is_heun or not in_first_order:
+                user_step += 1
+                # Only call step_callback if we haven't exceeded total_steps (e.g. the LCM
+                # scheduler may run more internal steps than user-facing steps).
+                if user_step <= total_steps:
+                    # Project the updated latents (now at t_prev) to an x0 estimate for previewing.
+                    preview_img = img - t_prev * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                            preview_img, 0.0
+                        )
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=2 if is_heun else 1,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+
+        return img
+
+    # Original Euler implementation (when scheduler is None)
     for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
 
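For reference, the capability probe this change relies on can be exercised standalone. A minimal sketch, assuming a recent diffusers install (both scheduler classes below are stock diffusers exports):

import inspect

from diffusers import FlowMatchEulerDiscreteScheduler, HeunDiscreteScheduler

# Probe each scheduler's set_timesteps() signature, as the diff does: schedulers
# that accept a `sigmas` argument can consume InvokeAI's time-shifted schedule
# directly; the rest must build their own schedule from num_inference_steps.
for cls in (FlowMatchEulerDiscreteScheduler, HeunDiscreteScheduler):
    supports_sigmas = "sigmas" in inspect.signature(cls.set_timesteps).parameters
    print(f"{cls.__name__}: supports custom sigmas = {supports_sigmas}")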
|
|