43 changes: 42 additions & 1 deletion invokeai/app/invocations/flux_denoise.py
@@ -32,6 +32,8 @@
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_from_preset
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
@@ -64,7 +66,7 @@
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.2.0",
version="4.3.0",
)
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
@@ -166,6 +168,24 @@ class FluxDenoiseInvocation(BaseInvocation):
        input=Input.Connection,
    )

    # DyPE (Dynamic Position Extrapolation) for high-resolution generation
    dype_preset: DyPEPreset = InputField(
        default=DyPEPreset.OFF,
        description="DyPE preset for high-resolution generation. 'auto' enables DyPE automatically for resolutions > 1536px; '4k' uses settings tuned for 4K output.",
    )
    dype_scale: Optional[float] = InputField(
        default=None,
        ge=0.0,
        le=8.0,
        description="DyPE magnitude (λs). Higher values give stronger extrapolation. Only used when dype_preset is not 'off'.",
    )
    dype_exponent: Optional[float] = InputField(
        default=None,
        ge=0.0,
        le=1000.0,
        description="DyPE decay speed (λt). Controls the transition from low- to high-frequency detail. Only used when dype_preset is not 'off'.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
@@ -422,6 +442,26 @@ def _run_diffusion(
            kontext_extension.ensure_batch_size(x.shape[0])
            img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids

        # Prepare DyPE extension for high-resolution generation
        dype_extension: DyPEExtension | None = None
        dype_config = get_dype_config_from_preset(
            preset=self.dype_preset,
            width=self.width,
            height=self.height,
            custom_scale=self.dype_scale,
            custom_exponent=self.dype_exponent,
        )
        if dype_config is not None:
            dype_extension = DyPEExtension(
                config=dype_config,
                target_height=self.height,
                target_width=self.width,
            )
            context.logger.info(
                f"DyPE enabled: {self.width}x{self.height}, preset={self.dype_preset.value}, "
                f"scale={dype_config.dype_scale:.2f}, method={dype_config.method}"
            )

        x = denoise(
            model=transformer,
            img=x,
@@ -439,6 +479,7 @@
            img_cond=img_cond,
            img_cond_seq=img_cond_seq,
            img_cond_seq_ids=img_cond_seq_ids,
            dype_extension=dype_extension,
            scheduler=scheduler,
        )
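As a quick orientation for reviewers, here is a minimal sketch of the preset-to-extension flow the node now performs, using only names visible in this diff (get_dype_config_from_preset, DyPEExtension, DyPEPreset). The AUTO member name and the convention that the 'off' preset yields a None config are assumptions inferred from the field description and the `if dype_config is not None` guard above.

from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_from_preset
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension

width, height = 2048, 2048  # target output resolution

# Resolve the preset into a concrete config; 'off' is expected to yield None.
dype_config = get_dype_config_from_preset(
    preset=DyPEPreset.AUTO,  # assumed member name for the 'auto' preset
    width=width,
    height=height,
    custom_scale=None,       # None falls back to the preset's defaults
    custom_exponent=None,
)

dype_extension = None
if dype_config is not None:
    # The extension is later handed to denoise(..., dype_extension=dype_extension).
    dype_extension = DyPEExtension(
        config=dype_config,
        target_height=height,
        target_width=width,
    )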
443 changes: 239 additions & 204 deletions invokeai/backend/flux/denoise.py

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions invokeai/backend/flux/dype/__init__.py
@@ -0,0 +1,18 @@
"""Dynamic Position Extrapolation (DyPE) for FLUX models.

DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
by dynamically scaling RoPE position embeddings during the denoising process.

Based on: https://github.com/wildminder/ComfyUI-DyPE
"""

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND
from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_for_resolution

__all__ = [
    "DyPEConfig",
    "DyPEEmbedND",
    "DyPEPreset",
    "get_dype_config_for_resolution",
]
226 changes: 226 additions & 0 deletions invokeai/backend/flux/dype/base.py
@@ -0,0 +1,226 @@
"""DyPE base configuration and utilities."""

import math
from dataclasses import dataclass
from typing import Literal

import torch
from torch import Tensor


@dataclass
class DyPEConfig:
    """Configuration for Dynamic Position Extrapolation."""

    enable_dype: bool = True
    base_resolution: int = 1024  # Native training resolution
    method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
    dype_scale: float = 2.0  # Magnitude λs (0.0-8.0)
    dype_exponent: float = 2.0  # Decay speed λt (0.0-1000.0)
    dype_start_sigma: float = 1.0  # When DyPE decay starts


def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
    """Calculate magnitude scaling factor.

    Args:
        scale: The resolution scaling factor
        mscale_factor: Adjustment factor for the scaling

    Returns:
        The magnitude scaling factor
    """
    if scale <= 1.0:
        return 1.0
    return mscale_factor * math.log(scale) + 1.0


def get_timestep_mscale(
    scale: float,
    current_sigma: float,
    dype_scale: float,
    dype_exponent: float,
    dype_start_sigma: float,
) -> float:
    """Calculate timestep-dependent magnitude scaling.

    The key insight of DyPE: early steps focus on low frequencies (global structure),
    late steps on high frequencies (details). This function modulates the scaling
    based on the current timestep/sigma.

    Args:
        scale: Resolution scaling factor
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_scale: DyPE magnitude (λs)
        dype_exponent: DyPE decay speed (λt)
        dype_start_sigma: Sigma threshold to start decay

    Returns:
        Timestep-modulated scaling factor
    """
    if scale <= 1.0:
        return 1.0

    # Normalize sigma to the [0, 1] range relative to start_sigma
    if current_sigma >= dype_start_sigma:
        t_normalized = 1.0
    else:
        t_normalized = current_sigma / dype_start_sigma

    # Apply exponential decay: stronger extrapolation early, weaker late.
    # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late.
    decay = math.exp(-dype_exponent * (1.0 - t_normalized))

    # Base mscale from resolution
    base_mscale = get_mscale(scale)

    # Interpolate toward base_mscale according to decay and dype_scale:
    # when decay=1 (early) the amplified value dominates; as decay approaches 0
    # (late steps) the result falls back to 1.0, i.e. no extra scaling.
    scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay

    return scaled_mscale
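# Illustrative only (not part of this file): how the decay behaves at a 4x
# resolution scale with the defaults above (dype_scale=2.0, dype_exponent=2.0,
# dype_start_sigma=1.0). Since get_mscale(4.0) = ln(4) + 1 ~= 2.39:
#   sigma=1.0 -> decay=1.00            -> mscale = 1 + 1.39 * 2.0 * 1.00 ~= 3.77 (strong, early)
#   sigma=0.5 -> decay=exp(-1.0) ~= 0.37 -> mscale ~= 2.02
#   sigma=0.1 -> decay=exp(-1.8) ~= 0.17 -> mscale ~= 1.46 (fading out as details form)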


def compute_vision_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale_h: float,
    scale_w: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using NTK-aware scaling for high resolution.

    This method extends FLUX's position encoding to handle resolutions beyond
    the 1024px training resolution by scaling the base frequency (theta).

    The NTK-aware approach smoothly interpolates frequencies to cover larger
    position ranges without breaking the attention patterns.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale_h: Height scaling factor
        scale_w: Width scaling factor
        current_sigma: Current noise level (reserved for future timestep-aware scaling)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    # Use the larger scale for the NTK calculation
    scale = max(scale_h, scale_w)

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK-aware theta scaling: extends position coverage for high res.
    # Formula: theta_scaled = theta * scale^(dim / (dim - 2))
    # This increases the wavelengths of the position encodings proportionally.
    if scale > 1.0:
        ntk_alpha = scale ** (dim / (dim - 2))
        scaled_theta = theta * ntk_alpha
    else:
        scaled_theta = theta

    # Standard RoPE frequency computation
    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    # Compute angles = position * frequency
    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)
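# Illustrative only (not part of this file): generating at 2048px from the
# 1024px base gives scale=2.0; with FLUX's per-axis spatial RoPE dim of 56
# (assumed, from axes_dim=[16, 56, 56]) the NTK factor is
# 2.0 ** (56 / 54) ~= 2.05, i.e. scaled_theta ~= 20_520 instead of 10_000.
# Example call:
#   pos = torch.arange(128, dtype=torch.float32)[None, :]  # (1, 128)
#   cos, sin = compute_vision_yarn_freqs(
#       pos, dim=56, theta=10_000, scale_h=2.0, scale_w=2.0,
#       current_sigma=1.0, dype_config=DyPEConfig(),
#   )
#   assert cos.shape == (1, 128, 28)  # dim/2 frequencies per position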


def compute_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using the YaRN/NTK method.

    Uses NTK-aware theta scaling for high-resolution support.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale: Uniform scaling factor
        current_sigma: Current noise level (reserved for future use)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK-aware theta scaling
    if scale > 1.0:
        ntk_alpha = scale ** (dim / (dim - 2))
        scaled_theta = theta * ntk_alpha
    else:
        scaled_theta = theta

    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)


def compute_ntk_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using the NTK method.

    Neural Tangent Kernel approach: continuous frequency scaling without
    timestep dependency.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale: Scaling factor

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK theta scaling
    scaled_theta = theta * (scale ** (dim / (dim - 2)))

    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)
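Not part of the PR, but a small harness like the following can sanity-check the three frequency helpers defined above. It relies only on this file and on the fact that at scale <= 1.0 every path leaves theta unscaled, so all three methods must agree exactly.

import torch

from invokeai.backend.flux.dype.base import (
    DyPEConfig,
    compute_ntk_freqs,
    compute_vision_yarn_freqs,
    compute_yarn_freqs,
)

pos = torch.arange(64, dtype=torch.float32)[None, :]  # (1, 64) token positions
cfg = DyPEConfig()

# At scale=1.0 no NTK stretching is applied, so yarn and ntk agree exactly.
cos_ntk, sin_ntk = compute_ntk_freqs(pos, dim=56, theta=10_000, scale=1.0)
cos_yarn, sin_yarn = compute_yarn_freqs(pos, dim=56, theta=10_000, scale=1.0, current_sigma=1.0, dype_config=cfg)
assert torch.allclose(cos_ntk, cos_yarn) and torch.allclose(sin_ntk, sin_yarn)

# vision_yarn reduces to the same computation when both axis scales are 1.0.
cos_v, _ = compute_vision_yarn_freqs(pos, dim=56, theta=10_000, scale_h=1.0, scale_w=1.0, current_sigma=1.0, dype_config=cfg)
assert torch.allclose(cos_ntk, cos_v)

# Each helper returns dim/2 frequencies per position: here (1, 64, 28).
assert cos_ntk.shape == (1, 64, 28)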