Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
366 changes: 366 additions & 0 deletions comfy_extras/nodes_dype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,366 @@
# adapted from https://github.com/guyyariv/DyPE

import math

import numpy as np
import torch
from typing_extensions import override

from comfy.model_patcher import ModelPatcher
from comfy_api.latest import ComfyExtension, io


def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
    """Invert the RoPE frequency formula: return the (fractional) dimension index
    whose frequency completes ``num_rotations`` rotations over ``max_position_embeddings``
    positions, for a rotary embedding of size ``dim`` with the given ``base``."""
    wavelength_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return dim * math.log(wavelength_ratio) / (2 * math.log(base))


def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
    """Compute the [low, high] dimension-index range used for NTK-by-parts
    interpolation, clamped into the valid index range ``[0, dim - 1]``."""
    low_edge = find_correction_factor(low_ratio, dim, base, ori_max_pe_len)
    high_edge = find_correction_factor(high_ratio, dim, base, ori_max_pe_len)
    low = np.floor(low_edge)
    high = np.ceil(high_edge)
    # Clamp values just in case the formula escapes the valid index range.
    return max(low, 0), min(high, dim - 1)


def linear_ramp_mask(low, high, dim):
    """Return a length-``dim`` float32 ramp that rises linearly from 0 (at index
    ``low``) to 1 (at index ``high``), clamped to [0, 1] outside that range.

    The parameters were renamed from ``min``/``max`` to avoid shadowing the
    builtins; all call sites pass them positionally.
    """
    if low == high:
        high += 0.001  # Prevent singularity (division by zero) for a degenerate ramp
    linear_func = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(linear_func, 0, 1)


def find_newbase_ntk(dim, base, scale):
    """Return the NTK-aware rescaled RoPE base: ``base * scale**(dim / (dim - 2))``.

    ``scale`` may be a Python float or a tensor; the result follows its type.
    """
    exponent = dim / (dim - 2)
    return base * scale**exponent


def get_1d_rotary_pos_embed(
    dim: int,
    pos: np.ndarray | int,
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,
    yarn=False,
    max_pe_len=None,
    ori_max_pe_len=64,
    dype=False,
    current_timestep=1.0,
):
    """
    Precompute the frequency tensor for complex exponentials with RoPE.
    Supports YARN interpolation for vision transformers.

    Args:
        dim (`int`):
            Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`):
            Position indices for the frequency tensor. [S] or scalar.
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*, defaults to False):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for linear interpolation.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for NTK-Aware RoPE.
        repeat_interleave_real (`bool`, *optional*, defaults to True):
            If True and use_real, real and imaginary parts are interleaved with themselves to reach dim.
            Otherwise, they are concatenated.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            Data type of the frequency tensor.
        yarn (`bool`, *optional*, defaults to False):
            If True, use YARN interpolation combining NTK, linear, and base methods.
        max_pe_len (`int`, *optional*):
            Maximum position encoding length (current patches for vision models).
        ori_max_pe_len (`int`, *optional*, defaults to 64):
            Original maximum position encoding length (base patches for vision models).
        dype (`bool`, *optional*, defaults to False):
            If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
        current_timestep (`float`, *optional*, defaults to 1.0):
            Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
        If use_real=True, returns tuple of (cos, sin) tensors.
    """
    # RoPE rotates dimension pairs, so the embedding size must be even.
    assert dim % 2 == 0

    # Normalize `pos` to a torch tensor: an int means "positions 0..pos-1".
    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)

    device = pos.device

    # YaRN path: only active when the current sequence actually exceeds the
    # original trained position range; otherwise fall through to plain/NTK RoPE.
    if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
        if not isinstance(max_pe_len, torch.Tensor):
            max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)

        # Extrapolation ratio (>= 1): how far beyond the trained range we are.
        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)

        # NOTE(review): beta/gamma appear to be the NTK-by-parts correction-range
        # boundaries (beta for the linear/NTK blend, gamma for the NTK/base
        # blend) — confirm against the DyPE reference implementation.
        beta_0 = 1.25
        beta_1 = 0.75
        gamma_0 = 16
        gamma_1 = 2

        # Three candidate frequency sets to blend per dimension:
        # unmodified base frequencies,
        freqs_base = 1.0 / (
            theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
        )

        # linearly interpolated (position-squashed) frequencies; einsum keeps
        # broadcasting correct when `scale` is a tensor,
        freqs_linear = 1.0 / torch.einsum(
            "..., f -> ... f",
            scale,
            (
                theta
                ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
            ),
        )

        # and NTK-aware frequencies computed from a rescaled base.
        new_base = find_newbase_ntk(dim, theta, scale)
        if new_base.dim() > 0:
            # Make a non-scalar base broadcast against the frequency axis.
            new_base = new_base.view(-1, 1)
        freqs_ntk = 1.0 / torch.pow(
            new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
        )
        if freqs_ntk.dim() > 1:
            freqs_ntk = freqs_ntk.squeeze()

        if dype:
            # DyPE: sharpen/relax the correction range with the noise level;
            # at timestep 1 (pure noise) the exponent is 2, at 0 it is 1 (no-op on 1.0-bounds).
            beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
            beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))

        # First blend: linear vs NTK frequencies across the beta range.
        low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
        low = max(0, low)
        high = min(dim // 2, high)

        freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(
            freqs_dtype
        )
        freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask

        if dype:
            # Same timestep-aware adjustment for the second (gamma) range.
            gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
            gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))

        # Second blend: mix the result with the unmodified base frequencies.
        low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
        low = max(0, low)
        high = min(dim // 2, high)

        freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(
            freqs_dtype
        )
        freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask

    else:
        # Plain RoPE, optionally NTK-scaled (larger effective theta) and/or
        # linearly interpolated (divide frequencies by linear_factor).
        theta_ntk = theta * ntk_factor
        freqs = (
            1.0
            / (
                theta_ntk
                ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
            )
            / linear_factor
        )

    # Outer product: per-position angle for each frequency -> [S, D/2].
    freqs = pos.unsqueeze(-1) * freqs

    # presumably Ascend NPU kernels need float32 here — TODO confirm
    is_npu = freqs.device.type == "npu"
    if is_npu:
        freqs = freqs.float()

    if use_real and repeat_interleave_real:
        # Duplicate each value pairwise so cos/sin reach the full `dim` width.
        freqs_cos = (
            freqs.cos()
            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
            .float()
        )
        freqs_sin = (
            freqs.sin()
            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
            .float()
        )

        if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
            # YaRN attention-temperature compensation ("mscale"): mildly
            # amplify the embedding as the scale factor grows.
            mscale = torch.where(
                scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
            ).to(scale)
            freqs_cos = freqs_cos * mscale
            freqs_sin = freqs_sin * mscale

        return freqs_cos, freqs_sin
    elif use_real:
        # Concatenated (not interleaved) layout.
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
        return freqs_cos, freqs_sin
    else:
        # Complex form: e^{i * angle} per position/frequency.
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis


class FluxPosEmbed(torch.nn.Module):
    """Drop-in replacement for Flux's rotary position embedder with DyPE support.

    Builds per-axis rotary embeddings from position ids and assembles them into
    2x2 rotation matrices. Axes beyond the first (i.e. spatial axes) may use
    YaRN or NTK scaling when the current patch count exceeds the base
    resolution's patch count; `dype=True` additionally modulates that scaling
    by the current denoising timestep (set via `set_timestep`).
    """

    def __init__(
        self,
        theta: int,
        axes_dim: list[int],
        method: str = "yarn",
        dype: bool = True,
    ):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        # 1024px base: 1024 // 8 latent pixels, // 2 per patch -> 64 patches per side.
        self.base_resolution = 1024
        self.base_patches = (self.base_resolution // 8) // 2
        self.method = method
        # "base" method means vanilla RoPE, so DyPE has nothing to modulate.
        self.dype = dype if method != "base" else False
        self.current_timestep = 1.0

    def set_timestep(self, timestep: float):
        """Set current timestep for DyPE. Timestep normalized to [0, 1] where 1 is pure noise."""
        self.current_timestep = timestep

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return stacked 2x2 rotation matrices for the given position ids.

        ids: [..., n_axes] integer-ish position ids, one column per axis.
        """
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        # bf16 on CUDA for speed/memory; full float32 elsewhere.
        freqs_dtype = torch.bfloat16 if ids.device.type == "cuda" else torch.float32

        for i in range(n_axes):
            axis_dim = self.axes_dim[i]
            axis_pos = pos[..., i]

            common_kwargs = {
                "dim": axis_dim,
                "pos": axis_pos,
                "theta": self.theta,
                "repeat_interleave_real": True,
                "use_real": True,
                "freqs_dtype": freqs_dtype,
            }

            # Axis 0 (presumably the text/temporal axis — confirm) is never
            # scaled; only spatial axes get YaRN/NTK treatment.
            if i > 0:
                max_pos = axis_pos.max().item()
                current_patches = max_pos + 1

                if self.method == "yarn" and current_patches > self.base_patches:
                    max_pe_len = torch.tensor(
                        current_patches, dtype=freqs_dtype, device=pos.device
                    )
                    cos, sin = get_1d_rotary_pos_embed(
                        **common_kwargs,
                        yarn=True,
                        max_pe_len=max_pe_len,
                        ori_max_pe_len=self.base_patches,
                        dype=self.dype,
                        current_timestep=self.current_timestep,
                    )

                elif self.method == "ntk" and current_patches > self.base_patches:
                    # NTK base factor from the patch-count ratio; with DyPE the
                    # factor decays toward 1 as the timestep approaches 0.
                    base_ntk = (current_patches / self.base_patches) ** (
                        self.axes_dim[i] / (self.axes_dim[i] - 2)
                    )
                    ntk_factor = (
                        base_ntk ** (2.0 * (self.current_timestep**2.0))
                        if self.dype
                        else base_ntk
                    )
                    ntk_factor = max(1.0, ntk_factor)

                    cos, sin = get_1d_rotary_pos_embed(
                        **common_kwargs, ntk_factor=ntk_factor
                    )

                else:
                    # Within the base range (or method == "base"): vanilla RoPE.
                    cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
            else:
                cos, sin = get_1d_rotary_pos_embed(**common_kwargs)

            cos_out.append(cos)
            sin_out.append(sin)

        emb_parts = []
        for cos, sin in zip(cos_out, sin_out):
            # cos/sin were pairwise-interleaved; keep one value per frequency.
            cos_reshaped = cos.view(*cos.shape[:-1], -1, 2)[..., :1]
            sin_reshaped = sin.view(*sin.shape[:-1], -1, 2)[..., :1]
            # Standard 2D rotation matrix [[cos, -sin], [sin, cos]] per frequency.
            row1 = torch.cat([cos_reshaped, -sin_reshaped], dim=-1)
            row2 = torch.cat([sin_reshaped, cos_reshaped], dim=-1)
            matrix = torch.stack([row1, row2], dim=-2)
            emb_parts.append(matrix)

        # Concatenate axes along the frequency dimension; add a broadcast dim
        # (presumably the attention-head axis — confirm against Flux's pe usage).
        emb = torch.cat(emb_parts, dim=-3)
        return emb.unsqueeze(1).to(ids.device)


def apply_dype_flux(model: ModelPatcher, method: str) -> ModelPatcher:
    """Clone `model` and patch its Flux positional embedder with a DyPE-enabled
    `FluxPosEmbed`, plus a wrapper that feeds the current sigma to it.

    `method` is one of "yarn", "ntk", "base" (see `FluxPosEmbed`).
    """
    # Already patched with this method — return unchanged.
    # NOTE(review): the marker is set on the shared inner model object below, so
    # this also short-circuits across clones — confirm that is intended.
    if getattr(model.model, "_dype", None) == method:
        return model

    m = model.clone()
    m.model._dype = method

    # Reuse the original embedder's rotary configuration.
    _pe_embedder = m.model.diffusion_model.pe_embedder
    _theta, _axes_dim = _pe_embedder.theta, _pe_embedder.axes_dim

    pos_embedder = FluxPosEmbed(_theta, _axes_dim, method, dype=True)
    m.add_object_patch("diffusion_model.pe_embedder", pos_embedder)

    # Used to normalize the per-step sigma into [0, 1] for DyPE.
    sigma_max = m.model.model_sampling.sigma_max.item()

    def dype_wrapper_function(model_function, args_dict):
        # Push the normalized timestep (1.0 = pure noise) into the embedder
        # before every model call, then run the model unchanged.
        timestep_tensor = args_dict.get("timestep")
        if timestep_tensor is not None and timestep_tensor.numel() > 0:
            current_sigma = timestep_tensor.flatten()[0].item()

            if sigma_max > 0:
                normalized_timestep = min(max(current_sigma / sigma_max, 0.0), 1.0)
                pos_embedder.set_timestep(normalized_timestep)

        input_x, c = args_dict.get("input"), args_dict.get("c", {})
        return model_function(input_x, args_dict.get("timestep"), **c)

    m.set_model_unet_function_wrapper(dype_wrapper_function)

    return m


class DyPEPatchModelFlux(io.ComfyNode):
    """ComfyUI node that applies DyPE position-encoding scaling to a Flux model."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs (model + scaling method) and model output."""
        return io.Schema(
            node_id="DyPEPatchModelFlux",
            display_name="DyPE Patch Model (Flux)",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Combo.Input(
                    "method",
                    # "yarn"/"ntk" enable extended-resolution scaling; "base"
                    # installs the embedder with vanilla RoPE behavior.
                    options=["yarn", "ntk", "base"],
                    default="yarn",
                ),
            ],
            outputs=[io.Model.Output()],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model: ModelPatcher, method: str) -> io.NodeOutput:
        """Return a clone of `model` patched via `apply_dype_flux`."""
        m = apply_dype_flux(model, method)
        return io.NodeOutput(m)


class DyPEExtension(ComfyExtension):
    """Extension wrapper exposing the DyPE node(s) to ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            DyPEPatchModelFlux,
        ]


async def comfy_entrypoint():
    """Module entrypoint: return the extension instance ComfyUI should load."""
    extension = DyPEExtension()
    return extension
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes():
"nodes_rope.py",
"nodes_logic.py",
"nodes_nop.py",
"nodes_dype.py",
]

import_failed = []
Expand Down