diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index be2c7d70fa3..085dc3a2534 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -32,6 +32,8 @@
 from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
 from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.denoise import denoise
+from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_from_preset
+from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
 from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
 from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
@@ -64,7 +66,7 @@
     title="FLUX Denoise",
     tags=["image", "flux"],
     category="image",
-    version="4.2.0",
+    version="4.3.0",
 )
 class FluxDenoiseInvocation(BaseInvocation):
     """Run denoising process with a FLUX transformer model."""
@@ -166,6 +168,24 @@ class FluxDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
     )
 
+    # DyPE (Dynamic Position Extrapolation) for high-resolution generation
+    dype_preset: DyPEPreset = InputField(
+        default=DyPEPreset.OFF,
+        description="DyPE preset for high-resolution generation. 'auto' enables automatically for resolutions > 1536px. '4k' uses optimized settings for 4K output.",
+    )
+    dype_scale: Optional[float] = InputField(
+        default=None,
+        ge=0.0,
+        le=8.0,
+        description="DyPE magnitude (λs). Higher values = stronger extrapolation. Only used when dype_preset is not 'off'.",
+    )
+    dype_exponent: Optional[float] = InputField(
+        default=None,
+        ge=0.0,
+        le=1000.0,
+        description="DyPE decay speed (λt). Controls transition from low to high frequency detail.
Only used when dype_preset is not 'off'.", + ) + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: latents = self._run_diffusion(context) @@ -422,6 +442,26 @@ def _run_diffusion( kontext_extension.ensure_batch_size(x.shape[0]) img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids + # Prepare DyPE extension for high-resolution generation + dype_extension: DyPEExtension | None = None + dype_config = get_dype_config_from_preset( + preset=self.dype_preset, + width=self.width, + height=self.height, + custom_scale=self.dype_scale, + custom_exponent=self.dype_exponent, + ) + if dype_config is not None: + dype_extension = DyPEExtension( + config=dype_config, + target_height=self.height, + target_width=self.width, + ) + context.logger.info( + f"DyPE enabled: {self.width}x{self.height}, preset={self.dype_preset.value}, " + f"scale={dype_config.dype_scale:.2f}, method={dype_config.method}" + ) + x = denoise( model=transformer, img=x, @@ -439,6 +479,7 @@ def _run_diffusion( img_cond=img_cond, img_cond_seq=img_cond_seq, img_cond_seq_ids=img_cond_seq_ids, + dype_extension=dype_extension, scheduler=scheduler, ) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 1d9492a5df7..30d075a5270 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -7,6 +7,7 @@ from tqdm import tqdm from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs +from invokeai.backend.flux.extensions.dype_extension import DyPEExtension from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension @@ -37,6 +38,8 @@ def denoise( # extra img tokens (sequence-wise) - for Kontext conditioning img_cond_seq: torch.Tensor | None = None, img_cond_seq_ids: torch.Tensor | None = None, + # DyPE extension for high-resolution generation + dype_extension: DyPEExtension | None = None, # Optional scheduler for alternative sampling methods scheduler: SchedulerMixin | None = None, ): @@ -74,30 +77,206 @@ def denoise( # Store original sequence length for slicing predictions original_seq_len = img.shape[1] - # Track the actual step for user-facing progress (accounts for Heun's double steps) - user_step = 0 + # DyPE: Patch model with DyPE-aware position embedder + dype_embedder = None + original_pe_embedder = None + if dype_extension is not None: + dype_embedder, original_pe_embedder = dype_extension.patch_model(model) + + try: + # Track the actual step for user-facing progress (accounts for Heun's double steps) + user_step = 0 + + if use_scheduler: + # Use diffusers scheduler for stepping + # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) + # This ensures progress bar shows 1/8, 2/8, etc. 
even when scheduler uses more internal steps + pbar = tqdm(total=total_steps, desc="Denoising") + for step_index in range(num_scheduler_steps): + timestep = scheduler.timesteps[step_index] + # Convert scheduler timestep (0-1000) to normalized (0-1) for the model + t_curr = timestep.item() / scheduler.config.num_train_timesteps + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + + # DyPE: Update step state for timestep-dependent scaling + if dype_extension is not None and dype_embedder is not None: + dype_extension.update_step_state( + embedder=dype_embedder, + timestep=t_curr, + timestep_index=user_step, + total_steps=total_steps, + ) - if use_scheduler: - # Use diffusers scheduler for stepping - # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps) - # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps - pbar = tqdm(total=total_steps, desc="Denoising") - for step_index in range(num_scheduler_steps): - timestep = scheduler.timesteps[step_index] - # Convert scheduler timestep (0-1000) to normalized (0-1) for the model - t_curr = timestep.item() / scheduler.config.num_train_timesteps - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + # For Heun scheduler, track if we're in first or second order step + is_heun = hasattr(scheduler, "state_in_first_order") + in_first_order = scheduler.state_in_first_order if is_heun else True + + # Run ControlNet models + controlnet_residuals: list[ControlNetFluxOutput] = [] + for controlnet_extension in controlnet_extensions: + controlnet_residuals.append( + controlnet_extension.run_controlnet( + timestep_index=user_step, + total_num_timesteps=total_steps, + img=img, + img_ids=img_ids, + txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, + timesteps=t_vec, + guidance=guidance_vec, + ) + ) + + merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals) + + # Prepare input for model + img_input = img + img_input_ids = img_ids + + if img_cond is not None: + img_input = torch.cat((img_input, img_cond), dim=-1) + + if img_cond_seq is not None: + assert img_cond_seq_ids is not None + img_input = torch.cat((img_input, img_cond_seq), dim=1) + img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) + + pred = model( + img=img_input, + img_ids=img_input_ids, + txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, + timesteps=t_vec, + guidance=guidance_vec, + timestep_index=user_step, + total_num_timesteps=total_steps, + controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals, + controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, + ip_adapter_extensions=pos_ip_adapter_extensions, + regional_prompting_extension=pos_regional_prompting_extension, + ) + + if img_cond_seq is not None: + pred = pred[:, :original_seq_len] + + # Get CFG scale for current user step + step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)] + + if not math.isclose(step_cfg_scale, 1.0): + if neg_regional_prompting_extension is None: + raise ValueError("Negative text conditioning is 
required when cfg_scale is not 1.0.") + + neg_img_input = img + neg_img_input_ids = img_ids + + if img_cond is not None: + neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1) + + if img_cond_seq is not None: + neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1) + neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1) + + neg_pred = model( + img=neg_img_input, + img_ids=neg_img_input_ids, + txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings, + timesteps=t_vec, + guidance=guidance_vec, + timestep_index=user_step, + total_num_timesteps=total_steps, + controlnet_double_block_residuals=None, + controlnet_single_block_residuals=None, + ip_adapter_extensions=neg_ip_adapter_extensions, + regional_prompting_extension=neg_regional_prompting_extension, + ) + + if img_cond_seq is not None: + neg_pred = neg_pred[:, :original_seq_len] + pred = neg_pred + step_cfg_scale * (pred - neg_pred) + + # Use scheduler.step() for the update + step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img) + img = step_output.prev_sample + + # Get t_prev for inpainting (next sigma value) + if step_index + 1 < len(scheduler.sigmas): + t_prev = scheduler.sigmas[step_index + 1].item() + else: + t_prev = 0.0 + + if inpaint_extension is not None: + img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev) + + # For Heun, only increment user step after second-order step completes + if is_heun: + if not in_first_order: + # Second order step completed + user_step += 1 + # Only call step_callback if we haven't exceeded total_steps + if user_step <= total_steps: + pbar.update(1) + preview_img = img - t_curr * pred + if inpaint_extension is not None: + preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents( + preview_img, 0.0 + ) + step_callback( + PipelineIntermediateState( + step=user_step, + order=2, + total_steps=total_steps, + timestep=int(t_curr * 1000), + latents=preview_img, + ), + ) + else: + # For LCM and other first-order schedulers + user_step += 1 + # Only call step_callback if we haven't exceeded total_steps + # (LCM scheduler may have more internal steps than user-facing steps) + if user_step <= total_steps: + pbar.update(1) + preview_img = img - t_curr * pred + if inpaint_extension is not None: + preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents( + preview_img, 0.0 + ) + step_callback( + PipelineIntermediateState( + step=user_step, + order=1, + total_steps=total_steps, + timestep=int(t_curr * 1000), + latents=preview_img, + ), + ) + + pbar.close() + return img + + # Original Euler implementation (when scheduler is None) + for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): + # DyPE: Update step state for timestep-dependent scaling + if dype_extension is not None and dype_embedder is not None: + dype_extension.update_step_state( + embedder=dype_embedder, + timestep=t_curr, + timestep_index=step_index, + total_steps=total_steps, + ) - # For Heun scheduler, track if we're in first or second order step - is_heun = hasattr(scheduler, "state_in_first_order") - in_first_order = scheduler.state_in_first_order if is_heun else True + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - # Run ControlNet models + # Run 
ControlNet models. controlnet_residuals: list[ControlNetFluxOutput] = [] for controlnet_extension in controlnet_extensions: controlnet_residuals.append( controlnet_extension.run_controlnet( - timestep_index=user_step, + timestep_index=step_index, total_num_timesteps=total_steps, img=img, img_ids=img_ids, @@ -109,17 +288,25 @@ def denoise( ) ) + # Merge the ControlNet residuals from multiple ControlNets. + # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the + # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same + # tensors. Calculating the sum materializes each tensor into its own instance. merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals) - # Prepare input for model + # Prepare input for model - concatenate fresh each step img_input = img img_input_ids = img_ids + # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.) if img_cond is not None: img_input = torch.cat((img_input, img_cond), dim=-1) + # Add sequence-wise conditioning (for Kontext) if img_cond_seq is not None: - assert img_cond_seq_ids is not None + assert img_cond_seq_ids is not None, ( + "You need to provide either both or neither of the sequence conditioning" + ) img_input = torch.cat((img_input, img_cond_seq), dim=1) img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) @@ -131,7 +318,7 @@ def denoise( y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, - timestep_index=user_step, + timestep_index=step_index, total_num_timesteps=total_steps, controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals, controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, @@ -139,22 +326,33 @@ def denoise( regional_prompting_extension=pos_regional_prompting_extension, ) + # Slice prediction to only include the main image tokens if img_cond_seq is not None: pred = pred[:, :original_seq_len] - # Get CFG scale for current user step - step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)] + step_cfg_scale = cfg_scale[step_index] + # If step_cfg_scale, is 1.0, then we don't need to run the negative prediction. if not math.isclose(step_cfg_scale, 1.0): + # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance + # on systems with sufficient VRAM. + if neg_regional_prompting_extension is None: raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.") + # For negative prediction with Kontext, we need to include the reference images + # to maintain consistency between positive and negative passes. 
Without this, + # CFG would create artifacts as the attention mechanism would see different + # spatial structures in each pass neg_img_input = img neg_img_input_ids = img_ids + # Add channel-wise conditioning for negative pass if present if img_cond is not None: neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1) + # Add sequence-wise conditioning (Kontext) for negative pass + # This ensures reference images are processed consistently if img_cond_seq is not None: neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1) neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1) @@ -167,7 +365,7 @@ def denoise( y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, - timestep_index=user_step, + timestep_index=step_index, total_num_timesteps=total_steps, controlnet_double_block_residuals=None, controlnet_single_block_residuals=None, @@ -175,194 +373,31 @@ def denoise( regional_prompting_extension=neg_regional_prompting_extension, ) + # Slice negative prediction to match main image tokens if img_cond_seq is not None: neg_pred = neg_pred[:, :original_seq_len] pred = neg_pred + step_cfg_scale * (pred - neg_pred) - # Use scheduler.step() for the update - step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img) - img = step_output.prev_sample - - # Get t_prev for inpainting (next sigma value) - if step_index + 1 < len(scheduler.sigmas): - t_prev = scheduler.sigmas[step_index + 1].item() - else: - t_prev = 0.0 + preview_img = img - t_curr * pred + img = img + (t_prev - t_curr) * pred if inpaint_extension is not None: img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev) - - # For Heun, only increment user step after second-order step completes - if is_heun: - if not in_first_order: - # Second order step completed - user_step += 1 - # Only call step_callback if we haven't exceeded total_steps - if user_step <= total_steps: - pbar.update(1) - preview_img = img - t_curr * pred - if inpaint_extension is not None: - preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents( - preview_img, 0.0 - ) - step_callback( - PipelineIntermediateState( - step=user_step, - order=2, - total_steps=total_steps, - timestep=int(t_curr * 1000), - latents=preview_img, - ), - ) - else: - # For LCM and other first-order schedulers - user_step += 1 - # Only call step_callback if we haven't exceeded total_steps - # (LCM scheduler may have more internal steps than user-facing steps) - if user_step <= total_steps: - pbar.update(1) - preview_img = img - t_curr * pred - if inpaint_extension is not None: - preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0) - step_callback( - PipelineIntermediateState( - step=user_step, - order=1, - total_steps=total_steps, - timestep=int(t_curr * 1000), - latents=preview_img, - ), - ) - - pbar.close() - return img - - # Original Euler implementation (when scheduler is None) - for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))): - t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) - - # Run ControlNet models. 
- controlnet_residuals: list[ControlNetFluxOutput] = [] - for controlnet_extension in controlnet_extensions: - controlnet_residuals.append( - controlnet_extension.run_controlnet( - timestep_index=step_index, - total_num_timesteps=total_steps, - img=img, - img_ids=img_ids, - txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, - txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, - y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, - timesteps=t_vec, - guidance=guidance_vec, - ) + preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0) + + step_callback( + PipelineIntermediateState( + step=step_index + 1, + order=1, + total_steps=total_steps, + timestep=int(t_curr), + latents=preview_img, + ), ) - # Merge the ControlNet residuals from multiple ControlNets. - # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the - # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same - # tensors. Calculating the sum materializes each tensor into its own instance. - merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals) - - # Prepare input for model - concatenate fresh each step - img_input = img - img_input_ids = img_ids - - # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.) - if img_cond is not None: - img_input = torch.cat((img_input, img_cond), dim=-1) - - # Add sequence-wise conditioning (for Kontext) - if img_cond_seq is not None: - assert img_cond_seq_ids is not None, ( - "You need to provide either both or neither of the sequence conditioning" - ) - img_input = torch.cat((img_input, img_cond_seq), dim=1) - img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) - - pred = model( - img=img_input, - img_ids=img_input_ids, - txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, - txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, - y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, - timesteps=t_vec, - guidance=guidance_vec, - timestep_index=step_index, - total_num_timesteps=total_steps, - controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals, - controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, - ip_adapter_extensions=pos_ip_adapter_extensions, - regional_prompting_extension=pos_regional_prompting_extension, - ) - - # Slice prediction to only include the main image tokens - if img_cond_seq is not None: - pred = pred[:, :original_seq_len] - - step_cfg_scale = cfg_scale[step_index] - - # If step_cfg_scale, is 1.0, then we don't need to run the negative prediction. - if not math.isclose(step_cfg_scale, 1.0): - # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance - # on systems with sufficient VRAM. - - if neg_regional_prompting_extension is None: - raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.") - - # For negative prediction with Kontext, we need to include the reference images - # to maintain consistency between positive and negative passes. 
Without this, - # CFG would create artifacts as the attention mechanism would see different - # spatial structures in each pass - neg_img_input = img - neg_img_input_ids = img_ids - - # Add channel-wise conditioning for negative pass if present - if img_cond is not None: - neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1) - - # Add sequence-wise conditioning (Kontext) for negative pass - # This ensures reference images are processed consistently - if img_cond_seq is not None: - neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1) - neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1) - - neg_pred = model( - img=neg_img_input, - img_ids=neg_img_input_ids, - txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings, - txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, - y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings, - timesteps=t_vec, - guidance=guidance_vec, - timestep_index=step_index, - total_num_timesteps=total_steps, - controlnet_double_block_residuals=None, - controlnet_single_block_residuals=None, - ip_adapter_extensions=neg_ip_adapter_extensions, - regional_prompting_extension=neg_regional_prompting_extension, - ) + return img - # Slice negative prediction to match main image tokens - if img_cond_seq is not None: - neg_pred = neg_pred[:, :original_seq_len] - pred = neg_pred + step_cfg_scale * (pred - neg_pred) - - preview_img = img - t_curr * pred - img = img + (t_prev - t_curr) * pred - - if inpaint_extension is not None: - img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev) - preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0) - - step_callback( - PipelineIntermediateState( - step=step_index + 1, - order=1, - total_steps=total_steps, - timestep=int(t_curr), - latents=preview_img, - ), - ) - - return img + finally: + # DyPE: Restore original position embedder + if original_pe_embedder is not None: + DyPEExtension.restore_model(model, original_pe_embedder) diff --git a/invokeai/backend/flux/dype/__init__.py b/invokeai/backend/flux/dype/__init__.py new file mode 100644 index 00000000000..488144201d8 --- /dev/null +++ b/invokeai/backend/flux/dype/__init__.py @@ -0,0 +1,18 @@ +"""Dynamic Position Extrapolation (DyPE) for FLUX models. + +DyPE enables high-resolution image generation (4K+) with pretrained FLUX models +by dynamically scaling RoPE position embeddings during the denoising process. 
+ +Based on: https://github.com/wildminder/ComfyUI-DyPE +""" + +from invokeai.backend.flux.dype.base import DyPEConfig +from invokeai.backend.flux.dype.embed import DyPEEmbedND +from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_for_resolution + +__all__ = [ + "DyPEConfig", + "DyPEEmbedND", + "DyPEPreset", + "get_dype_config_for_resolution", +] diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py new file mode 100644 index 00000000000..53f072efc7e --- /dev/null +++ b/invokeai/backend/flux/dype/base.py @@ -0,0 +1,226 @@ +"""DyPE base configuration and utilities.""" + +import math +from dataclasses import dataclass +from typing import Literal + +import torch +from torch import Tensor + + +@dataclass +class DyPEConfig: + """Configuration for Dynamic Position Extrapolation.""" + + enable_dype: bool = True + base_resolution: int = 1024 # Native training resolution + method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn" + dype_scale: float = 2.0 # Magnitude λs (0.0-8.0) + dype_exponent: float = 2.0 # Decay speed λt (0.0-1000.0) + dype_start_sigma: float = 1.0 # When DyPE decay starts + + +def get_mscale(scale: float, mscale_factor: float = 1.0) -> float: + """Calculate magnitude scaling factor. + + Args: + scale: The resolution scaling factor + mscale_factor: Adjustment factor for the scaling + + Returns: + The magnitude scaling factor + """ + if scale <= 1.0: + return 1.0 + return mscale_factor * math.log(scale) + 1.0 + + +def get_timestep_mscale( + scale: float, + current_sigma: float, + dype_scale: float, + dype_exponent: float, + dype_start_sigma: float, +) -> float: + """Calculate timestep-dependent magnitude scaling. + + The key insight of DyPE: early steps focus on low frequencies (global structure), + late steps on high frequencies (details). This function modulates the scaling + based on the current timestep/sigma. + + Args: + scale: Resolution scaling factor + current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) + dype_scale: DyPE magnitude (λs) + dype_exponent: DyPE decay speed (λt) + dype_start_sigma: Sigma threshold to start decay + + Returns: + Timestep-modulated scaling factor + """ + if scale <= 1.0: + return 1.0 + + # Normalize sigma to [0, 1] range relative to start_sigma + if current_sigma >= dype_start_sigma: + t_normalized = 1.0 + else: + t_normalized = current_sigma / dype_start_sigma + + # Apply exponential decay: stronger extrapolation early, weaker late + # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late + decay = math.exp(-dype_exponent * (1.0 - t_normalized)) + + # Base mscale from resolution + base_mscale = get_mscale(scale) + + # Interpolate between base_mscale and 1.0 based on decay and dype_scale + # When decay=1 (early): use scaled value + # When decay=0 (late): use base value + scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay + + return scaled_mscale + + +def compute_vision_yarn_freqs( + pos: Tensor, + dim: int, + theta: int, + scale_h: float, + scale_w: float, + current_sigma: float, + dype_config: DyPEConfig, +) -> tuple[Tensor, Tensor]: + """Compute RoPE frequencies using NTK-aware scaling for high-resolution. + + This method extends FLUX's position encoding to handle resolutions beyond + the 1024px training resolution by scaling the base frequency (theta). + + The NTK-aware approach smoothly interpolates frequencies to cover larger + position ranges without breaking the attention patterns. 
+
+    Args:
+        pos: Position tensor
+        dim: Embedding dimension
+        theta: RoPE base frequency
+        scale_h: Height scaling factor
+        scale_w: Width scaling factor
+        current_sigma: Current noise level (reserved for future timestep-aware scaling)
+        dype_config: DyPE configuration
+
+    Returns:
+        Tuple of (cos, sin) frequency tensors
+    """
+    assert dim % 2 == 0
+
+    # Use the larger scale for NTK calculation
+    scale = max(scale_h, scale_w)
+
+    device = pos.device
+    dtype = torch.float64 if device.type != "mps" else torch.float32
+
+    # NTK-aware theta scaling: extends position coverage for high-res
+    # Formula: theta_scaled = theta * scale^(dim/(dim-2))
+    # This increases the wavelength of position encodings proportionally
+    if scale > 1.0:
+        ntk_alpha = scale ** (dim / (dim - 2))
+        scaled_theta = theta * ntk_alpha
+    else:
+        scaled_theta = theta
+
+    # Standard RoPE frequency computation
+    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+    freqs = 1.0 / (scaled_theta**freq_seq)
+
+    # Compute angles = position * frequency
+    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+    cos = torch.cos(angles)
+    sin = torch.sin(angles)
+
+    return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+def compute_yarn_freqs(
+    pos: Tensor,
+    dim: int,
+    theta: int,
+    scale: float,
+    current_sigma: float,
+    dype_config: DyPEConfig,
+) -> tuple[Tensor, Tensor]:
+    """Compute RoPE frequencies using YARN/NTK method.
+
+    Uses NTK-aware theta scaling for high-resolution support.
+
+    Args:
+        pos: Position tensor
+        dim: Embedding dimension
+        theta: RoPE base frequency
+        scale: Uniform scaling factor
+        current_sigma: Current noise level (reserved for future use)
+        dype_config: DyPE configuration
+
+    Returns:
+        Tuple of (cos, sin) frequency tensors
+    """
+    assert dim % 2 == 0
+
+    device = pos.device
+    dtype = torch.float64 if device.type != "mps" else torch.float32
+
+    # NTK-aware theta scaling
+    if scale > 1.0:
+        ntk_alpha = scale ** (dim / (dim - 2))
+        scaled_theta = theta * ntk_alpha
+    else:
+        scaled_theta = theta
+
+    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+    freqs = 1.0 / (scaled_theta**freq_seq)
+
+    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+    cos = torch.cos(angles)
+    sin = torch.sin(angles)
+
+    return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+def compute_ntk_freqs(
+    pos: Tensor,
+    dim: int,
+    theta: int,
+    scale: float,
+) -> tuple[Tensor, Tensor]:
+    """Compute RoPE frequencies using NTK method.
+
+    Neural Tangent Kernel approach - continuous frequency scaling without
+    timestep dependency.
+ + Args: + pos: Position tensor + dim: Embedding dimension + theta: RoPE base frequency + scale: Scaling factor + + Returns: + Tuple of (cos, sin) frequency tensors + """ + assert dim % 2 == 0 + + device = pos.device + dtype = torch.float64 if device.type != "mps" else torch.float32 + + # NTK scaling + scaled_theta = theta * (scale ** (dim / (dim - 2))) + + freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim + freqs = 1.0 / (scaled_theta**freq_seq) + + angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) + + cos = torch.cos(angles) + sin = torch.sin(angles) + + return cos.to(pos.dtype), sin.to(pos.dtype) diff --git a/invokeai/backend/flux/dype/embed.py b/invokeai/backend/flux/dype/embed.py new file mode 100644 index 00000000000..ace6a56ab0f --- /dev/null +++ b/invokeai/backend/flux/dype/embed.py @@ -0,0 +1,116 @@ +"""DyPE-enhanced position embedding module.""" + +import torch +from torch import Tensor, nn + +from invokeai.backend.flux.dype.base import DyPEConfig +from invokeai.backend.flux.dype.rope import rope_dype + + +class DyPEEmbedND(nn.Module): + """N-dimensional position embedding with DyPE support. + + This class replaces the standard EmbedND from FLUX with a DyPE-aware version + that dynamically scales position embeddings based on resolution and timestep. + + The key difference from EmbedND: + - Maintains step state (current_sigma, target dimensions) + - Uses rope_dype() instead of rope() for frequency computation + - Applies timestep-dependent scaling for better high-resolution generation + """ + + def __init__( + self, + dim: int, + theta: int, + axes_dim: list[int], + dype_config: DyPEConfig, + ): + """Initialize DyPE position embedder. + + Args: + dim: Total embedding dimension (sum of axes_dim) + theta: RoPE base frequency + axes_dim: Dimension allocation per axis (e.g., [16, 56, 56] for FLUX) + dype_config: DyPE configuration + """ + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + self.dype_config = dype_config + + # Step state - updated before each denoising step + self._current_sigma: float = 1.0 + self._target_height: int = 1024 + self._target_width: int = 1024 + + def set_step_state(self, sigma: float, height: int, width: int) -> None: + """Update the step state before each denoising step. + + This method should be called by the DyPE extension before each step + to update the current noise level and target dimensions. + + Args: + sigma: Current noise level (timestep value, 1.0 = full noise) + height: Target image height in pixels + width: Target image width in pixels + """ + self._current_sigma = sigma + self._target_height = height + self._target_width = width + + def forward(self, ids: Tensor) -> Tensor: + """Compute position embeddings with DyPE scaling. 
+ + Args: + ids: Position indices tensor with shape (batch, seq_len, n_axes) + For FLUX: n_axes=3 (time/channel, height, width) + + Returns: + Position embedding tensor with shape (batch, 1, seq_len, dim) + """ + n_axes = ids.shape[-1] + + # Compute RoPE for each axis with DyPE scaling + embeddings = [] + for i in range(n_axes): + axis_emb = rope_dype( + pos=ids[..., i], + dim=self.axes_dim[i], + theta=self.theta, + current_sigma=self._current_sigma, + target_height=self._target_height, + target_width=self._target_width, + dype_config=self.dype_config, + ) + embeddings.append(axis_emb) + + # Concatenate embeddings from all axes + emb = torch.cat(embeddings, dim=-3) + + return emb.unsqueeze(1) + + @classmethod + def from_embednd( + cls, + embed_nd: nn.Module, + dype_config: DyPEConfig, + ) -> "DyPEEmbedND": + """Create a DyPEEmbedND from an existing EmbedND. + + This is a convenience method for patching an existing FLUX model. + + Args: + embed_nd: Original EmbedND module from FLUX + dype_config: DyPE configuration + + Returns: + New DyPEEmbedND with same parameters + """ + return cls( + dim=embed_nd.dim, + theta=embed_nd.theta, + axes_dim=embed_nd.axes_dim, + dype_config=dype_config, + ) diff --git a/invokeai/backend/flux/dype/presets.py b/invokeai/backend/flux/dype/presets.py new file mode 100644 index 00000000000..5071b51f9ba --- /dev/null +++ b/invokeai/backend/flux/dype/presets.py @@ -0,0 +1,141 @@ +"""DyPE presets and automatic configuration.""" + +from dataclasses import dataclass +from enum import Enum + +from invokeai.backend.flux.dype.base import DyPEConfig + + +class DyPEPreset(str, Enum): + """Predefined DyPE configurations.""" + + OFF = "off" # DyPE disabled + AUTO = "auto" # Automatically enable based on resolution + PRESET_4K = "4k" # Optimized for 3840x2160 / 4096x2160 + + +@dataclass +class DyPEPresetConfig: + """Preset configuration values.""" + + base_resolution: int + method: str + dype_scale: float + dype_exponent: float + dype_start_sigma: float + + +# Predefined preset configurations +DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = { + DyPEPreset.PRESET_4K: DyPEPresetConfig( + base_resolution=1024, + method="vision_yarn", + dype_scale=2.0, + dype_exponent=2.0, + dype_start_sigma=1.0, + ), +} + + +def get_dype_config_for_resolution( + width: int, + height: int, + base_resolution: int = 1024, + activation_threshold: int = 1536, +) -> DyPEConfig | None: + """Automatically determine DyPE config based on target resolution. + + FLUX can handle resolutions up to ~1.5x natively without significant artifacts. + DyPE is only activated when the resolution exceeds the activation threshold. 
+ + Args: + width: Target image width in pixels + height: Target image height in pixels + base_resolution: Native training resolution of the model (for scale calculation) + activation_threshold: Resolution threshold above which DyPE is activated + + Returns: + DyPEConfig if DyPE should be enabled, None otherwise + """ + max_dim = max(width, height) + + if max_dim <= activation_threshold: + return None # FLUX can handle this natively + + # Calculate scaling factor based on base_resolution + scale = max_dim / base_resolution + + # Dynamic parameters based on scaling + # Higher resolution = higher dype_scale, capped at 8.0 + dynamic_dype_scale = min(2.0 * scale, 8.0) + + return DyPEConfig( + enable_dype=True, + base_resolution=base_resolution, + method="vision_yarn", + dype_scale=dynamic_dype_scale, + dype_exponent=2.0, + dype_start_sigma=1.0, + ) + + +def get_dype_config_from_preset( + preset: DyPEPreset, + width: int, + height: int, + custom_scale: float | None = None, + custom_exponent: float | None = None, +) -> DyPEConfig | None: + """Get DyPE configuration from a preset or custom values. + + Args: + preset: The DyPE preset to use + width: Target image width + height: Target image height + custom_scale: Optional custom dype_scale (overrides preset) + custom_exponent: Optional custom dype_exponent (overrides preset) + + Returns: + DyPEConfig if DyPE should be enabled, None otherwise + """ + if preset == DyPEPreset.OFF: + # Check if custom values are provided even with preset=OFF + if custom_scale is not None: + return DyPEConfig( + enable_dype=True, + base_resolution=1024, + method="vision_yarn", + dype_scale=custom_scale, + dype_exponent=custom_exponent if custom_exponent is not None else 2.0, + dype_start_sigma=1.0, + ) + return None + + if preset == DyPEPreset.AUTO: + config = get_dype_config_for_resolution( + width=width, + height=height, + base_resolution=1024, + activation_threshold=1536, + ) + # Apply custom overrides if provided + if config is not None: + if custom_scale is not None: + config.dype_scale = custom_scale + if custom_exponent is not None: + config.dype_exponent = custom_exponent + return config + + # Use preset configuration + preset_config = DYPE_PRESETS.get(preset) + if preset_config is None: + return None + + return DyPEConfig( + enable_dype=True, + base_resolution=preset_config.base_resolution, + method=preset_config.method, + dype_scale=custom_scale if custom_scale is not None else preset_config.dype_scale, + dype_exponent=custom_exponent if custom_exponent is not None else preset_config.dype_exponent, + dype_start_sigma=preset_config.dype_start_sigma, + ) diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py new file mode 100644 index 00000000000..f6a1594f6be --- /dev/null +++ b/invokeai/backend/flux/dype/rope.py @@ -0,0 +1,110 @@ +"""DyPE-enhanced RoPE (Rotary Position Embedding) functions.""" + +import torch +from einops import rearrange +from torch import Tensor + +from invokeai.backend.flux.dype.base import ( + DyPEConfig, + compute_ntk_freqs, + compute_vision_yarn_freqs, + compute_yarn_freqs, +) + + +def rope_dype( + pos: Tensor, + dim: int, + theta: int, + current_sigma: float, + target_height: int, + target_width: int, + dype_config: DyPEConfig, +) -> Tensor: + """Compute RoPE with Dynamic Position Extrapolation. + + This is the core DyPE function that replaces the standard rope() function. + It applies resolution-aware and timestep-aware scaling to position embeddings. 
+ + Args: + pos: Position indices tensor + dim: Embedding dimension per axis + theta: RoPE base frequency (typically 10000) + current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) + target_height: Target image height in pixels + target_width: Target image width in pixels + dype_config: DyPE configuration + + Returns: + Rotary position embedding tensor with shape suitable for FLUX attention + """ + assert dim % 2 == 0 + + # Calculate scaling factors + base_res = dype_config.base_resolution + scale_h = target_height / base_res + scale_w = target_width / base_res + scale = max(scale_h, scale_w) + + # If no scaling needed and DyPE disabled, use base method + if not dype_config.enable_dype or scale <= 1.0: + return _rope_base(pos, dim, theta) + + # Select method and compute frequencies + method = dype_config.method + + if method == "vision_yarn": + cos, sin = compute_vision_yarn_freqs( + pos=pos, + dim=dim, + theta=theta, + scale_h=scale_h, + scale_w=scale_w, + current_sigma=current_sigma, + dype_config=dype_config, + ) + elif method == "yarn": + cos, sin = compute_yarn_freqs( + pos=pos, + dim=dim, + theta=theta, + scale=scale, + current_sigma=current_sigma, + dype_config=dype_config, + ) + elif method == "ntk": + cos, sin = compute_ntk_freqs( + pos=pos, + dim=dim, + theta=theta, + scale=scale, + ) + else: # "base" + return _rope_base(pos, dim, theta) + + # Construct rotation matrix from cos/sin + # Output shape: (batch, seq_len, dim/2, 2, 2) + out = torch.stack([cos, -sin, sin, cos], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + + return out.to(dtype=pos.dtype, device=pos.device) + + +def _rope_base(pos: Tensor, dim: int, theta: int) -> Tensor: + """Standard RoPE without DyPE scaling. + + This matches the original rope() function from invokeai.backend.flux.math. + """ + assert dim % 2 == 0 + + device = pos.device + dtype = torch.float64 if device.type != "mps" else torch.float32 + + scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim + omega = 1.0 / (theta**scale) + + out = torch.einsum("...n,d->...nd", pos.to(dtype), omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + + return out.to(dtype=pos.dtype, device=pos.device) diff --git a/invokeai/backend/flux/extensions/dype_extension.py b/invokeai/backend/flux/extensions/dype_extension.py new file mode 100644 index 00000000000..db27c053dd3 --- /dev/null +++ b/invokeai/backend/flux/extensions/dype_extension.py @@ -0,0 +1,91 @@ +"""DyPE extension for FLUX denoising pipeline.""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from invokeai.backend.flux.dype.base import DyPEConfig +from invokeai.backend.flux.dype.embed import DyPEEmbedND + +if TYPE_CHECKING: + from invokeai.backend.flux.model import Flux + + +@dataclass +class DyPEExtension: + """Extension for Dynamic Position Extrapolation in FLUX models. + + This extension manages the patching of the FLUX model's position embedder + and updates the step state during denoising. + + Usage: + 1. Create extension with config and target dimensions + 2. Call patch_model() to replace pe_embedder with DyPE version + 3. Call update_step_state() before each denoising step + 4. 
Call restore_model() after denoising to restore original embedder + """ + + config: DyPEConfig + target_height: int + target_width: int + + def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]: + """Patch the model's position embedder with DyPE version. + + Args: + model: The FLUX model to patch + + Returns: + Tuple of (new DyPE embedder, original embedder for restoration) + """ + original_embedder = model.pe_embedder + + dype_embedder = DyPEEmbedND.from_embednd( + embed_nd=original_embedder, + dype_config=self.config, + ) + + # Set initial state + dype_embedder.set_step_state( + sigma=1.0, + height=self.target_height, + width=self.target_width, + ) + + # Replace the embedder + model.pe_embedder = dype_embedder + + return dype_embedder, original_embedder + + def update_step_state( + self, + embedder: DyPEEmbedND, + timestep: float, + timestep_index: int, + total_steps: int, + ) -> None: + """Update the step state in the DyPE embedder. + + This should be called before each denoising step to update the + current noise level for timestep-dependent scaling. + + Args: + embedder: The DyPE embedder to update + timestep: Current timestep value (sigma/noise level) + timestep_index: Current step index (0-based) + total_steps: Total number of denoising steps + """ + embedder.set_step_state( + sigma=timestep, + height=self.target_height, + width=self.target_width, + ) + + @staticmethod + def restore_model(model: "Flux", original_embedder: object) -> None: + """Restore the original position embedder. + + Args: + model: The FLUX model to restore + original_embedder: The original embedder saved from patch_model() + """ + model.pe_embedder = original_embedder diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index c7aaf800681..71c9f2e5c6f 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -835,6 +835,7 @@ "cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)", "clipSkip": "$t(parameters.clipSkip)", "createdBy": "Created By", + "dypePreset": "DyPE Preset", "generationMode": "Generation Mode", "guidance": "Guidance", "height": "Height", @@ -1374,6 +1375,7 @@ "scaledHeight": "Scaled H", "scaledWidth": "Scaled W", "scheduler": "Scheduler", + "dypePreset": "DyPE (High-Res)", "seamlessXAxis": "Seamless X Axis", "seamlessYAxis": "Seamless Y Axis", "colorCompensation": "Color Compensation", @@ -1610,6 +1612,13 @@ "Each scheduler defines how to iteratively add noise to an image or how to update a sample based on a model's output." ] }, + "fluxDypePreset": { + "heading": "DyPE (High-Resolution)", + "paragraphs": [ + "Dynamic Position Extrapolation (DyPE) improves FLUX generation quality at resolutions above the training size (1024px).", + "Off: Standard generation. Auto: Automatically enables for images > 1536px. 4K: Optimized settings for 4K resolution output." 
+ ] + }, "seedVarianceEnhancer": { "heading": "Seed Variance Enhancer", "paragraphs": [ diff --git a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts index fc045356e0c..fddb7b4439d 100644 --- a/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts +++ b/invokeai/frontend/web/src/common/components/InformationalPopover/constants.ts @@ -3,6 +3,7 @@ import denoisingStrength from 'public/assets/images/denoising-strength.png'; export type Feature = | 'clipSkip' + | 'fluxDypePreset' | 'hrf' | 'paramNegativeConditioning' | 'paramPositiveConditioning' @@ -88,6 +89,9 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = { clipSkip: { href: 'https://support.invoke.ai/support/solutions/articles/151000178161-advanced-settings', }, + fluxDypePreset: { + placement: 'right', + }, inpainting: { href: 'https://support.invoke.ai/support/solutions/articles/151000096702-inpainting-outpainting-and-bounding-box', }, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 0190aba602b..58efe225edb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -72,6 +72,9 @@ const slice = createSlice({ setFluxScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => { state.fluxScheduler = action.payload; }, + setFluxDypePreset: (state, action: PayloadAction<'off' | 'auto' | '4k'>) => { + state.fluxDypePreset = action.payload; + }, setZImageScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => { state.zImageScheduler = action.payload; }, @@ -465,6 +468,7 @@ export const { setGuidance, setScheduler, setFluxScheduler, + setFluxDypePreset, setZImageScheduler, setZImageSeedVarianceEnabled, setZImageSeedVarianceStrength, @@ -609,6 +613,7 @@ export const selectModelSupportsOptimizedDenoising = createSelector( ); export const selectScheduler = createParamsSelector((params) => params.scheduler); export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler); +export const selectFluxDypePreset = createParamsSelector((params) => params.fluxDypePreset); export const selectZImageScheduler = createParamsSelector((params) => params.zImageScheduler); export const selectZImageSeedVarianceEnabled = createParamsSelector((params) => params.zImageSeedVarianceEnabled); export const selectZImageSeedVarianceStrength = createParamsSelector((params) => params.zImageSeedVarianceStrength); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index fbd3d415b79..13b417bc9d7 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -9,6 +9,7 @@ import { zParameterCLIPGEmbedModel, zParameterCLIPLEmbedModel, zParameterControlLoRAModel, + zParameterFluxDypePreset, zParameterFluxScheduler, zParameterGuidance, zParameterImageDimension, @@ -599,6 +600,7 @@ export const zParamsState = z.object({ iterations: z.number(), scheduler: zParameterScheduler, fluxScheduler: zParameterFluxScheduler, + fluxDypePreset: zParameterFluxDypePreset, zImageScheduler: zParameterZImageScheduler, upscaleScheduler: zParameterScheduler, upscaleCfgScale: zParameterCFGScale, @@ -659,6 +661,7 @@ 
export const getInitialParamsState = (): ParamsState => ({ iterations: 1, scheduler: 'dpmpp_3m_k', fluxScheduler: 'euler', + fluxDypePreset: 'off', zImageScheduler: 'euler', upscaleScheduler: 'kdpm_2', upscaleCfgScale: 2, diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index d9201f15ffa..c3efb22d885 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -16,6 +16,7 @@ import { setCfgRescaleMultiplier, setCfgScale, setClipSkip, + setFluxDypePreset, setFluxScheduler, setGuidance, setImg2imgStrength, @@ -51,6 +52,7 @@ import type { ParameterCFGRescaleMultiplier, ParameterCFGScale, ParameterCLIPSkip, + ParameterFluxDypePreset, ParameterGuidance, ParameterHeight, ParameterModel, @@ -74,6 +76,7 @@ import { zParameterCFGRescaleMultiplier, zParameterCFGScale, zParameterCLIPSkip, + zParameterFluxDypePreset, zParameterGuidance, zParameterImageDimension, zParameterNegativePrompt, @@ -368,6 +371,26 @@ const Guidance: SingleMetadataHandler = { }; //#endregion Guidance +//#region FluxDypePreset +const FluxDypePreset: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'FluxDypePreset', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'dype_preset'); + const parsed = zParameterFluxDypePreset.parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(setFluxDypePreset(value)); + }, + i18nKey: 'metadata.dypePreset', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; +//#endregion FluxDypePreset + //#region Scheduler const Scheduler: SingleMetadataHandler = { [SingleMetadataKey]: true, @@ -1064,6 +1087,7 @@ export const ImageMetadataHandlers = { CFGRescaleMultiplier, CLIPSkip, Guidance, + FluxDypePreset, Width, Height, Seed, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 994d9091431..ed4a2c41012 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -70,6 +70,9 @@ export const zFluxSchedulerField = z.enum(['euler', 'heun', 'lcm']); // Z-Image scheduler options (Flow Matching schedulers, same as Flux) export const zZImageSchedulerField = z.enum(['euler', 'heun', 'lcm']); + +// Flux DyPE (Dynamic Position Extrapolation) preset options for high-resolution generation +export const zFluxDypePresetField = z.enum(['off', 'auto', '4k']); // #endregion // #region Model-related schemas diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index d54697cf559..a9321cb6e42 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -43,7 +43,15 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise 1536px)' }, + { value: '4k', label: '4K Optimized' }, +]; + +const ParamFluxDypePreset = () => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const fluxDypePreset = useAppSelector(selectFluxDypePreset); + + const onChange = useCallback( + (v) => { + if (!isParameterFluxDypePreset(v?.value)) { + return; + } + dispatch(setFluxDypePreset(v.value)); + }, + [dispatch] + ); + + const value = useMemo(() => 
FLUX_DYPE_PRESET_OPTIONS.find((o) => o.value === fluxDypePreset), [fluxDypePreset]); + + return ( + + + {t('parameters.dypePreset')} + + + + ); +}; + +export default memo(ParamFluxDypePreset); diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index fb059dec2d5..6cfd4b3fa58 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -2,6 +2,7 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; import { buildZodTypeGuard } from 'common/util/zodUtils'; import { + zFluxDypePresetField, zFluxSchedulerField, zModelIdentifierField, zSchedulerField, @@ -76,6 +77,11 @@ export const [zParameterZImageScheduler, isParameterZImageScheduler] = buildPara export type ParameterZImageScheduler = z.infer; // #endregion +// #region Flux DyPE Preset +export const [zParameterFluxDypePreset, isParameterFluxDypePreset] = buildParameter(zFluxDypePresetField); +export type ParameterFluxDypePreset = z.infer; +// #endregion + // #region seed export const [zParameterSeed, isParameterSeed] = buildParameter(z.number().int().min(0).max(NUMPY_RAND_MAX)); export type ParameterSeed = z.infer; diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx index 90ad40c7cfa..40a832b5232 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx @@ -8,6 +8,7 @@ import { selectIsCogView4, selectIsFLUX, selectIsSD3, selectIsZImage } from 'fea import { LoRAList } from 'features/lora/components/LoRAList'; import LoRASelect from 'features/lora/components/LoRASelect'; import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale'; +import ParamFluxDypePreset from 'features/parameters/components/Core/ParamFluxDypePreset'; import ParamFluxScheduler from 'features/parameters/components/Core/ParamFluxScheduler'; import ParamGuidance from 'features/parameters/components/Core/ParamGuidance'; import ParamScheduler from 'features/parameters/components/Core/ParamScheduler'; @@ -72,6 +73,7 @@ export const GenerationSettingsAccordion = memo(() => { {!isFLUX && !isSD3 && !isCogView4 && !isZImage && } {isFLUX && } + {isFLUX && } {isZImage && } {isFLUX && modelConfig && !isFluxFillMainModelModelConfig(modelConfig) && } diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 829c5d435cd..ee5aa1ad55c 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -6990,6 +6990,12 @@ export type components = { */ download_path: string; }; + /** + * DyPEPreset + * @description Predefined DyPE configurations. 
+ * @enum {string} + */ + DyPEPreset: "off" | "auto" | "4k"; /** * Dynamic Prompt * @description Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator @@ -8365,6 +8371,23 @@ export type components = { * @default null */ kontext_conditioning?: components["schemas"]["FluxKontextConditioningField"] | components["schemas"]["FluxKontextConditioningField"][] | null; + /** + * @description DyPE preset for high-resolution generation. 'auto' enables automatically for resolutions > 1536px. '4k' uses optimized settings for 4K output. + * @default off + */ + dype_preset?: components["schemas"]["DyPEPreset"]; + /** + * Dype Scale + * @description DyPE magnitude (λs). Higher values = stronger extrapolation. Only used when dype_preset is not 'off'. + * @default null + */ + dype_scale?: number | null; + /** + * Dype Exponent + * @description DyPE decay speed (λt). Controls transition from low to high frequency detail. Only used when dype_preset is not 'off'. + * @default null + */ + dype_exponent?: number | null; /** * type * @default flux_denoise @@ -8540,6 +8563,23 @@ export type components = { * @default null */ kontext_conditioning?: components["schemas"]["FluxKontextConditioningField"] | components["schemas"]["FluxKontextConditioningField"][] | null; + /** + * @description DyPE preset for high-resolution generation. 'auto' enables automatically for resolutions > 1536px. '4k' uses optimized settings for 4K output. + * @default off + */ + dype_preset?: components["schemas"]["DyPEPreset"]; + /** + * Dype Scale + * @description DyPE magnitude (λs). Higher values = stronger extrapolation. Only used when dype_preset is not 'off'. + * @default null + */ + dype_scale?: number | null; + /** + * Dype Exponent + * @description DyPE decay speed (λt). Controls transition from low to high frequency detail. Only used when dype_preset is not 'off'. 
+       * @default null
+       */
+      dype_exponent?: number | null;
      /**
       * type
       * @default flux_denoise_meta
diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py
new file mode 100644
index 00000000000..1c027e298f6
--- /dev/null
+++ b/tests/backend/flux/dype/test_dype.py
@@ -0,0 +1,358 @@
+"""Tests for DyPE (Dynamic Position Extrapolation) module."""
+
+import torch
+
+from invokeai.backend.flux.dype.base import (
+    DyPEConfig,
+    compute_ntk_freqs,
+    compute_vision_yarn_freqs,
+    compute_yarn_freqs,
+    get_mscale,
+    get_timestep_mscale,
+)
+from invokeai.backend.flux.dype.embed import DyPEEmbedND
+from invokeai.backend.flux.dype.presets import (
+    DYPE_PRESETS,
+    DyPEPreset,
+    get_dype_config_for_resolution,
+    get_dype_config_from_preset,
+)
+from invokeai.backend.flux.dype.rope import rope_dype
+
+
+class TestDyPEConfig:
+    """Tests for DyPEConfig dataclass."""
+
+    def test_default_values(self):
+        config = DyPEConfig()
+        assert config.enable_dype is True
+        assert config.base_resolution == 1024
+        assert config.method == "vision_yarn"
+        assert config.dype_scale == 2.0
+        assert config.dype_exponent == 2.0
+        assert config.dype_start_sigma == 1.0
+
+    def test_custom_values(self):
+        config = DyPEConfig(
+            enable_dype=False,
+            base_resolution=512,
+            method="yarn",
+            dype_scale=4.0,
+            dype_exponent=3.0,
+            dype_start_sigma=0.5,
+        )
+        assert config.enable_dype is False
+        assert config.base_resolution == 512
+        assert config.method == "yarn"
+        assert config.dype_scale == 4.0
+
+
+class TestMscale:
+    """Tests for mscale calculation functions."""
+
+    def test_get_mscale_no_scaling(self):
+        """When scale <= 1.0, mscale should be 1.0."""
+        assert get_mscale(1.0) == 1.0
+        assert get_mscale(0.5) == 1.0
+
+    def test_get_mscale_with_scaling(self):
+        """When scale > 1.0, mscale should increase."""
+        mscale_2x = get_mscale(2.0)
+        mscale_4x = get_mscale(4.0)
+
+        assert mscale_2x > 1.0
+        assert mscale_4x > mscale_2x
+
+    def test_get_timestep_mscale_no_scaling(self):
+        """When scale <= 1.0, timestep_mscale should be 1.0."""
+        result = get_timestep_mscale(
+            scale=1.0,
+            current_sigma=0.5,
+            dype_scale=2.0,
+            dype_exponent=2.0,
+            dype_start_sigma=1.0,
+        )
+        assert result == 1.0
+
+    def test_get_timestep_mscale_high_sigma(self):
+        """Early steps (high sigma) should have stronger scaling."""
+        early_mscale = get_timestep_mscale(
+            scale=2.0,
+            current_sigma=1.0,  # Early step
+            dype_scale=2.0,
+            dype_exponent=2.0,
+            dype_start_sigma=1.0,
+        )
+        late_mscale = get_timestep_mscale(
+            scale=2.0,
+            current_sigma=0.1,  # Late step
+            dype_scale=2.0,
+            dype_exponent=2.0,
+            dype_start_sigma=1.0,
+        )
+
+        # Early steps should have larger mscale than late steps
+        assert early_mscale >= late_mscale
+
+
+class TestRopeDype:
+    """Tests for DyPE-enhanced RoPE function."""
+
+    def test_rope_dype_shape(self):
+        """Test that rope_dype returns correct shape."""
+        pos = torch.zeros(1, 64)
+        dim = 64
+        theta = 10000
+
+        config = DyPEConfig()
+        result = rope_dype(
+            pos=pos,
+            dim=dim,
+            theta=theta,
+            current_sigma=0.5,
+            target_height=2048,
+            target_width=2048,
+            dype_config=config,
+        )
+
+        # Shape should be (batch, seq_len, dim/2, 2, 2)
+        assert result.shape == (1, 64, dim // 2, 2, 2)
+
+    def test_rope_dype_no_scaling(self):
+        """Scaling beyond the base resolution should change the output relative to the unscaled case."""
+        pos = torch.arange(16).unsqueeze(0).float()
+        dim = 32
+        theta = 10000
+
+        config = DyPEConfig(base_resolution=1024)
+
+        # No scaling needed: target matches the base resolution
+        result_no_scale = rope_dype(
+            pos=pos,
+            dim=dim,
+            theta=theta,
+            current_sigma=0.5,
+            target_height=1024,
+            target_width=1024,
+            dype_config=config,
+        )
+
+        # With scaling
+        result_with_scale = rope_dype(
+            pos=pos,
+            dim=dim,
+            theta=theta,
+            current_sigma=0.5,
+            target_height=2048,
+            target_width=2048,
+            dype_config=config,
+        )
+
+        # Results should be different when scaling is applied
+        assert not torch.allclose(result_no_scale, result_with_scale)
+
+
+class TestDyPEEmbedND:
+    """Tests for DyPEEmbedND module."""
+
+    def test_init(self):
+        """Test DyPEEmbedND initialization."""
+        config = DyPEConfig()
+        embedder = DyPEEmbedND(
+            dim=128,
+            theta=10000,
+            axes_dim=[16, 56, 56],
+            dype_config=config,
+        )
+
+        assert embedder.dim == 128
+        assert embedder.theta == 10000
+        assert embedder.axes_dim == [16, 56, 56]
+
+    def test_set_step_state(self):
+        """Test step state update."""
+        config = DyPEConfig()
+        embedder = DyPEEmbedND(
+            dim=128,
+            theta=10000,
+            axes_dim=[16, 56, 56],
+            dype_config=config,
+        )
+
+        embedder.set_step_state(sigma=0.5, height=2048, width=2048)
+
+        assert embedder._current_sigma == 0.5
+        assert embedder._target_height == 2048
+        assert embedder._target_width == 2048
+
+    def test_forward_shape(self):
+        """Test forward pass output shape."""
+        config = DyPEConfig()
+        embedder = DyPEEmbedND(
+            dim=128,
+            theta=10000,
+            axes_dim=[16, 56, 56],
+            dype_config=config,
+        )
+
+        # Create input ids tensor (batch=1, seq_len=64, n_axes=3)
+        ids = torch.zeros(1, 64, 3)
+
+        result = embedder(ids)
+
+        # The embedder returns rope-formatted output:
+        # shape (batch, 1, seq_len, dim/2, 2, 2)
+        assert result.dim() == 6
+        assert result.shape[0] == 1  # batch
+        assert result.shape[1] == 1  # unsqueeze
+        assert result.shape[2] == 64  # seq_len
+
+
+class TestDyPEPresets:
+    """Tests for DyPE preset configurations."""
+
+    def test_preset_4k_exists(self):
+        """Test that 4K preset is defined."""
+        assert DyPEPreset.PRESET_4K in DYPE_PRESETS
+
+    def test_get_dype_config_for_resolution_below_threshold(self):
+        """When resolution is below threshold, should return None."""
+        config = get_dype_config_for_resolution(
+            width=1024,
+            height=1024,
+            activation_threshold=1536,
+        )
+        assert config is None
+
+        config = get_dype_config_for_resolution(
+            width=1536,
+            height=1024,
+            activation_threshold=1536,
+        )
+        assert config is None
+
+    def test_get_dype_config_for_resolution_above_threshold(self):
+        """When resolution is above threshold, should return config."""
+        config = get_dype_config_for_resolution(
+            width=2048,
+            height=2048,
+            activation_threshold=1536,
+        )
+        assert config is not None
+        assert config.enable_dype is True
+        assert config.method == "vision_yarn"
+
+    def test_get_dype_config_for_resolution_dynamic_scale(self):
+        """Higher resolution should result in higher dype_scale."""
+        config_2k = get_dype_config_for_resolution(
+            width=2048,
+            height=2048,
+            base_resolution=1024,
+            activation_threshold=1536,
+        )
+        config_4k = get_dype_config_for_resolution(
+            width=4096,
+            height=4096,
+            base_resolution=1024,
+            activation_threshold=1536,
+        )
+
+        assert config_2k is not None
+        assert config_4k is not None
+        assert config_4k.dype_scale > config_2k.dype_scale
+
+    def test_get_dype_config_from_preset_off(self):
+        """Preset OFF should return None."""
+        config = get_dype_config_from_preset(
+            preset=DyPEPreset.OFF,
+            width=2048,
+            height=2048,
+        )
+        assert config is None
+
+    def test_get_dype_config_from_preset_auto(self):
+        """Preset AUTO should use resolution-based config."""
+        config = get_dype_config_from_preset(
+            preset=DyPEPreset.AUTO,
+            width=2048,
+            height=2048,
+        )
+        assert config is not None
+        assert config.enable_dype is True
+
+    def test_get_dype_config_from_preset_4k(self):
+        """Preset 4K should use 4K settings."""
+        config = get_dype_config_from_preset(
+            preset=DyPEPreset.PRESET_4K,
+            width=3840,
+            height=2160,
+        )
+        assert config is not None
+        assert config.enable_dype is True
+
+    def test_get_dype_config_from_preset_custom_overrides(self):
+        """Custom scale/exponent should override preset values."""
+        config = get_dype_config_from_preset(
+            preset=DyPEPreset.PRESET_4K,
+            width=3840,
+            height=2160,
+            custom_scale=5.0,
+            custom_exponent=10.0,
+        )
+        assert config is not None
+        assert config.dype_scale == 5.0
+        assert config.dype_exponent == 10.0
+
+
+class TestFrequencyComputation:
+    """Tests for frequency computation functions."""
+
+    def test_compute_vision_yarn_freqs_shape(self):
+        """Test vision_yarn frequency computation shape."""
+        pos = torch.arange(16).unsqueeze(0).float()
+        config = DyPEConfig()
+
+        cos, sin = compute_vision_yarn_freqs(
+            pos=pos,
+            dim=32,
+            theta=10000,
+            scale_h=2.0,
+            scale_w=2.0,
+            current_sigma=0.5,
+            dype_config=config,
+        )
+
+        assert cos.shape == sin.shape
+        assert cos.shape[0] == 1  # batch
+        assert cos.shape[1] == 16  # seq_len
+
+    def test_compute_yarn_freqs_shape(self):
+        """Test yarn frequency computation shape."""
+        pos = torch.arange(16).unsqueeze(0).float()
+        config = DyPEConfig()
+
+        cos, sin = compute_yarn_freqs(
+            pos=pos,
+            dim=32,
+            theta=10000,
+            scale=2.0,
+            current_sigma=0.5,
+            dype_config=config,
+        )
+
+        assert cos.shape == sin.shape
+        assert cos.shape[0] == 1
+
+    def test_compute_ntk_freqs_shape(self):
+        """Test ntk frequency computation shape."""
+        pos = torch.arange(16).unsqueeze(0).float()
+
+        cos, sin = compute_ntk_freqs(
+            pos=pos,
+            dim=32,
+            theta=10000,
+            scale=2.0,
+        )
+
+        assert cos.shape == sin.shape
+        assert cos.shape[0] == 1
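
Reviewer note, not part of the patch: the TestMscale assertions above pin down only the qualitative behaviour of get_timestep_mscale (λs, i.e. dype_scale, sets the magnitude of the attention correction; λt, i.e. dype_exponent, sets how quickly that correction decays as sigma falls), not its exact formula. Below is a minimal sketch consistent with those assertions, assuming a YaRN-style log-based mscale; the sketch_* names and the precise expression are hypothetical and not taken from this diff.

# Illustrative sketch only; NOT the implementation added by this diff.
import math


def sketch_get_mscale(scale: float, mscale: float = 1.0) -> float:
    # Hypothetical YaRN-style attention rescale factor: 1.0 when not extrapolating.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def sketch_get_timestep_mscale(
    scale: float,
    current_sigma: float,
    dype_scale: float,  # lambda_s: magnitude of the correction
    dype_exponent: float,  # lambda_t: how fast the correction decays over the schedule
    dype_start_sigma: float,
) -> float:
    if scale <= 1.0:
        return 1.0
    full = sketch_get_mscale(scale, mscale=dype_scale)
    # Normalised position in the schedule: 1.0 at the first step, toward 0.0 at the end.
    t = max(0.0, min(current_sigma / dype_start_sigma, 1.0))
    # Blend from the full correction (early, high sigma) back to 1.0 (late, low sigma).
    return 1.0 + (full - 1.0) * (t**dype_exponent)


# The properties asserted in TestMscale hold for this sketch:
assert sketch_get_timestep_mscale(1.0, 0.5, 2.0, 2.0, 1.0) == 1.0
assert sketch_get_timestep_mscale(2.0, 1.0, 2.0, 2.0, 1.0) >= sketch_get_timestep_mscale(2.0, 0.1, 2.0, 2.0, 1.0)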
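
Also outside the patch: the schema additions show that the three new fields ride on the existing flux_denoise node, with dype_preset defaulting to "off" and the two overrides nullable. A hedged sketch of a node payload that opts into the '4k' preset follows; only the dype_* field names, enum values, and defaults come from this diff, while the node id and surrounding graph wiring are assumptions.

# Hypothetical flux_denoise node payload; only the dype_* fields are defined by this diff.
flux_denoise_node = {
    "id": "flux_denoise_1",  # assumed node id
    "type": "flux_denoise",
    "width": 3840,
    "height": 2160,
    "dype_preset": "4k",  # one of "off" | "auto" | "4k"; default is "off"
    "dype_scale": None,  # None falls back to the preset's lambda_s
    "dype_exponent": None,  # None falls back to the preset's lambda_t
}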