Feature: Add DyPE (Dynamic Position Extrapolation) support to FLUX models for improved high-resolution image generation #8763
Status: Open. Pfannkuchensack wants to merge 11 commits into invoke-ai:main from Pfannkuchensack:claude/assess-dype-port-3mt4G.
Commits (11):
- ebe2e94 docs: add DyPE implementation plan for FLUX high-resolution generation (claude)
- 5f39572 docs: update DyPE plan with design decisions (claude)
- f58e15e docs: add activation threshold for DyPE auto mode (claude)
- 55c8bdc feat(flux): implement DyPE for high-resolution generation (claude)
- 102bdff Merge branch 'main' into claude/assess-dype-port-3mt4G (Pfannkuchensack)
- c173bef feat(flux): add DyPE preset selector to Linear UI (Pfannkuchensack)
- 19c6b56 feat(flux): add DyPE preset to metadata recall (Pfannkuchensack)
- c2cdf05 chore: remove dype-implementation-plan.md (Pfannkuchensack)
- e9a08c7 chore(flux): bump flux_denoise version to 4.3.0 (Pfannkuchensack)
- 50f071a chore: ruff check fix (Pfannkuchensack)
- 29b66d0 chore: ruff format (Pfannkuchensack)
New file (+18 lines): invokeai/backend/flux/dype/__init__.py

```python
"""Dynamic Position Extrapolation (DyPE) for FLUX models.

DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
by dynamically scaling RoPE position embeddings during the denoising process.

Based on: https://github.com/wildminder/ComfyUI-DyPE
"""

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND
from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_for_resolution

__all__ = [
    "DyPEConfig",
    "DyPEEmbedND",
    "DyPEPreset",
    "get_dype_config_for_resolution",
]
```
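The core mechanism behind this package is NTK-aware scaling of the RoPE base frequency theta. A minimal, torch-free numeric sketch of that scaling rule (the values theta=10000 and dim=56 are illustrative assumptions, not taken from this PR):

```python
import math


def ntk_scaled_theta(theta: float, dim: int, scale: float) -> float:
    """NTK-aware theta scaling: theta * scale^(dim/(dim-2)).

    Identity when not extrapolating (scale <= 1.0).
    """
    if scale <= 1.0:
        return theta
    return theta * scale ** (dim / (dim - 2))


# Illustrative values: RoPE base 10000, one even axis dimension of 56.
theta, dim = 10000.0, 56
for scale in (1.0, 2.0, 4.0):
    # The slowest RoPE frequency is theta'^(-(dim-2)/dim), so its wavelength
    # grows by exactly `scale`: (scale^(dim/(dim-2)))^((dim-2)/dim) = scale.
    # This is what extends position coverage beyond the training range.
    print(scale, ntk_scaled_theta(theta, dim, scale))
```

The exponent dim/(dim-2) is chosen so that after re-deriving the per-dimension frequencies, the longest wavelength stretches by precisely the resolution scale factor.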
New file (+226 lines): invokeai/backend/flux/dype/base.py

```python
"""DyPE base configuration and utilities."""

import math
from dataclasses import dataclass
from typing import Literal

import torch
from torch import Tensor


@dataclass
class DyPEConfig:
    """Configuration for Dynamic Position Extrapolation."""

    enable_dype: bool = True
    base_resolution: int = 1024  # Native training resolution
    method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
    dype_scale: float = 2.0  # Magnitude λs (0.0-8.0)
    dype_exponent: float = 2.0  # Decay speed λt (0.0-1000.0)
    dype_start_sigma: float = 1.0  # When DyPE decay starts


def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
    """Calculate magnitude scaling factor.

    Args:
        scale: The resolution scaling factor
        mscale_factor: Adjustment factor for the scaling

    Returns:
        The magnitude scaling factor
    """
    if scale <= 1.0:
        return 1.0
    return mscale_factor * math.log(scale) + 1.0


def get_timestep_mscale(
    scale: float,
    current_sigma: float,
    dype_scale: float,
    dype_exponent: float,
    dype_start_sigma: float,
) -> float:
    """Calculate timestep-dependent magnitude scaling.

    The key insight of DyPE: early steps focus on low frequencies (global structure),
    late steps on high frequencies (details). This function modulates the scaling
    based on the current timestep/sigma.

    Args:
        scale: Resolution scaling factor
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_scale: DyPE magnitude (λs)
        dype_exponent: DyPE decay speed (λt)
        dype_start_sigma: Sigma threshold to start decay

    Returns:
        Timestep-modulated scaling factor
    """
    if scale <= 1.0:
        return 1.0

    # Normalize sigma to [0, 1] range relative to start_sigma
    if current_sigma >= dype_start_sigma:
        t_normalized = 1.0
    else:
        t_normalized = current_sigma / dype_start_sigma

    # Apply exponential decay: stronger extrapolation early, weaker late
    # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
    decay = math.exp(-dype_exponent * (1.0 - t_normalized))

    # Base mscale from resolution
    base_mscale = get_mscale(scale)

    # Interpolate between base_mscale and 1.0 based on decay and dype_scale
    # When decay=1 (early): use scaled value
    # When decay=0 (late): use base value
    scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay

    return scaled_mscale


def compute_vision_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale_h: float,
    scale_w: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using NTK-aware scaling for high-resolution.

    This method extends FLUX's position encoding to handle resolutions beyond
    the 1024px training resolution by scaling the base frequency (theta).

    The NTK-aware approach smoothly interpolates frequencies to cover larger
    position ranges without breaking the attention patterns.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale_h: Height scaling factor
        scale_w: Width scaling factor
        current_sigma: Current noise level (reserved for future timestep-aware scaling)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    # Use the larger scale for NTK calculation
    scale = max(scale_h, scale_w)

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK-aware theta scaling: extends position coverage for high-res
    # Formula: theta_scaled = theta * scale^(dim/(dim-2))
    # This increases the wavelength of position encodings proportionally
    if scale > 1.0:
        ntk_alpha = scale ** (dim / (dim - 2))
        scaled_theta = theta * ntk_alpha
    else:
        scaled_theta = theta

    # Standard RoPE frequency computation
    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    # Compute angles = position * frequency
    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)


def compute_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using YARN/NTK method.

    Uses NTK-aware theta scaling for high-resolution support.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale: Uniform scaling factor
        current_sigma: Current noise level (reserved for future use)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK-aware theta scaling
    if scale > 1.0:
        ntk_alpha = scale ** (dim / (dim - 2))
        scaled_theta = theta * ntk_alpha
    else:
        scaled_theta = theta

    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)


def compute_ntk_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using NTK method.

    Neural Tangent Kernel approach - continuous frequency scaling without
    timestep dependency.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale: Scaling factor

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK scaling
    scaled_theta = theta * (scale ** (dim / (dim - 2)))

    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)
```
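The timestep-dependent decay in get_timestep_mscale can be checked with plain math. A standalone sketch, with the two scaling helpers re-inlined from the diff above so it runs without invokeai or torch; the parameter values (4x upscale, λs=2.0, λt=2.0) are illustrative, not presets from this PR:

```python
import math


def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
    # Resolution-based magnitude scale, as in the diff above.
    if scale <= 1.0:
        return 1.0
    return mscale_factor * math.log(scale) + 1.0


def get_timestep_mscale(
    scale: float,
    current_sigma: float,
    dype_scale: float,
    dype_exponent: float,
    dype_start_sigma: float,
) -> float:
    # Timestep-modulated scale, as in the diff above.
    if scale <= 1.0:
        return 1.0
    t_normalized = 1.0 if current_sigma >= dype_start_sigma else current_sigma / dype_start_sigma
    decay = math.exp(-dype_exponent * (1.0 - t_normalized))
    base_mscale = get_mscale(scale)
    return 1.0 + (base_mscale - 1.0) * dype_scale * decay


# 4x upscale, λs=2.0, λt=2.0, start_sigma=1.0 (illustrative values).
early = get_timestep_mscale(4.0, 1.0, 2.0, 2.0, 1.0)   # full noise: ≈ 3.77
late = get_timestep_mscale(4.0, 0.05, 2.0, 2.0, 1.0)   # nearly clean: ≈ 1.41
print(early, late)
```

This shows the intended behavior: at high sigma (early, global-structure steps) the extrapolation boost is at its maximum of 1 + λs·(ln(scale)), and it decays exponentially toward the unmodified base as sigma approaches zero.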