diff --git a/comfy_extras/nodes_dype.py b/comfy_extras/nodes_dype.py
new file mode 100644
index 000000000000..3cc0be8a714d
--- /dev/null
+++ b/comfy_extras/nodes_dype.py
@@ -0,0 +1,366 @@
+# adapted from https://github.com/guyyariv/DyPE
+
+import math
+
+import numpy as np
+import torch
+from typing_extensions import override
+
+from comfy.model_patcher import ModelPatcher
+from comfy_api.latest import ComfyExtension, io
+
+
+def find_correction_factor(num_rotations, dim, base, max_position_embeddings):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+        2 * math.log(base)
+    )  # Inverse dim formula to find number of rotations
+
+
+def find_correction_range(low_ratio, high_ratio, dim, base, ori_max_pe_len):
+    """Find the correction range for NTK-by-parts interpolation"""
+    low = np.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len))
+    high = np.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def find_newbase_ntk(dim, base, scale):
+    """Calculate the new base for NTK-aware scaling"""
+    return base * (scale ** (dim / (dim - 2)))
+
+
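+# Editor's note (illustrative, not part of the upstream DyPE code): for
+# dim=56, base=10000.0 and scale=2.0, find_newbase_ntk returns
+# 10000.0 * 2.0 ** (56 / 54), roughly 20500. Raising the rotary base this way
+# stretches the lowest-frequency bands to cover the enlarged position range
+# while leaving the highest-frequency bands nearly untouched.
+
+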
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: np.ndarray | int,
+    theta: float = 10000.0,
+    use_real=False,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    repeat_interleave_real=True,
+    freqs_dtype=torch.float32,
+    yarn=False,
+    max_pe_len=None,
+    ori_max_pe_len=64,
+    dype=False,
+    current_timestep=1.0,
+):
+    """
+    Precompute the frequency tensor for complex exponentials with RoPE.
+    Supports YARN interpolation for vision transformers.
+
+    Args:
+        dim (`int`):
+            Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`):
+            Position indices for the frequency tensor. [S] or scalar.
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation.
+        use_real (`bool`, *optional*, defaults to False):
+            If True, return the real and imaginary parts separately. Otherwise, return complex numbers.
+        linear_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for linear interpolation.
+        ntk_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for NTK-Aware RoPE.
+        repeat_interleave_real (`bool`, *optional*, defaults to True):
+            If True and use_real, the real and imaginary parts are each interleaved with themselves to reach dim.
+            Otherwise, they are concatenated.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            Data type of the frequency tensor.
+        yarn (`bool`, *optional*, defaults to False):
+            If True, use YARN interpolation combining the NTK, linear, and base methods.
+        max_pe_len (`int`, *optional*):
+            Maximum position encoding length (current patches for vision models).
+        ori_max_pe_len (`int`, *optional*, defaults to 64):
+            Original maximum position encoding length (base patches for vision models).
+        dype (`bool`, *optional*, defaults to False):
+            If True, enable DyPE (Dynamic Position Encoding) with timestep-aware scaling.
+        current_timestep (`float`, *optional*, defaults to 1.0):
+            Current timestep for DyPE, normalized to [0, 1] where 1 is pure noise.
+
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+        If use_real=True, a tuple of (cos, sin) tensors is returned instead.
+    """
+    assert dim % 2 == 0
+
+    if isinstance(pos, int):
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)
+
+    device = pos.device
+
+    if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
+        if not isinstance(max_pe_len, torch.Tensor):
+            max_pe_len = torch.tensor(max_pe_len, dtype=freqs_dtype, device=device)
+
+        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
+
+        beta_0 = 1.25
+        beta_1 = 0.75
+        gamma_0 = 16
+        gamma_1 = 2
+
+        freqs_base = 1.0 / (
+            theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
+        )
+
+        freqs_linear = 1.0 / torch.einsum(
+            "..., f -> ... f",
+            scale,
+            (
+                theta
+                ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
+            ),
+        )
+
+        new_base = find_newbase_ntk(dim, theta, scale)
+        if new_base.dim() > 0:
+            new_base = new_base.view(-1, 1)
+        freqs_ntk = 1.0 / torch.pow(
+            new_base, (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
+        )
+        if freqs_ntk.dim() > 1:
+            freqs_ntk = freqs_ntk.squeeze()
+
+        if dype:
+            beta_0 = beta_0 ** (2.0 * (current_timestep**2.0))
+            beta_1 = beta_1 ** (2.0 * (current_timestep**2.0))
+
+        low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
+        low = max(0, low)
+        high = min(dim // 2, high)
+
+        freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(
+            freqs_dtype
+        )
+        freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
+
+        if dype:
+            gamma_0 = gamma_0 ** (2.0 * (current_timestep**2.0))
+            gamma_1 = gamma_1 ** (2.0 * (current_timestep**2.0))
+
+        low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
+        low = max(0, low)
+        high = min(dim // 2, high)
+
+        freqs_mask = 1 - linear_ramp_mask(low, high, dim // 2).to(device).to(
+            freqs_dtype
+        )
+        freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
+
+    else:
+        theta_ntk = theta * ntk_factor
+        freqs = (
+            1.0
+            / (
+                theta_ntk
+                ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=device) / dim)
+            )
+            / linear_factor
+        )
+
+    freqs = pos.unsqueeze(-1) * freqs
+
+    is_npu = freqs.device.type == "npu"
+    if is_npu:
+        freqs = freqs.float()
+
+    if use_real and repeat_interleave_real:
+        freqs_cos = (
+            freqs.cos()
+            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
+            .float()
+        )
+        freqs_sin = (
+            freqs.sin()
+            .repeat_interleave(2, dim=-1, output_size=freqs.shape[-1] * 2)
+            .float()
+        )
+
+        if yarn and max_pe_len is not None and max_pe_len > ori_max_pe_len:
+            mscale = torch.where(
+                scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0
+            ).to(scale)
+            freqs_cos = freqs_cos * mscale
+            freqs_sin = freqs_sin * mscale
+
+        return freqs_cos, freqs_sin
+    elif use_real:
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+
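+# Usage sketch (editor's illustration; the numbers are hypothetical):
+#
+#   cos, sin = get_1d_rotary_pos_embed(
+#       dim=56, pos=256, use_real=True,
+#       yarn=True, max_pe_len=256, ori_max_pe_len=64,
+#       dype=True, current_timestep=0.8,
+#   )
+#
+# This takes the YARN path (max_pe_len > ori_max_pe_len) and yields cos and
+# sin tensors of shape (256, 56), with the blending ranges modulated by
+# current_timestep because dype=True.
+
+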
+class FluxPosEmbed(torch.nn.Module):
+    def __init__(
+        self,
+        theta: int,
+        axes_dim: list[int],
+        method: str = "yarn",
+        dype: bool = True,
+    ):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.base_resolution = 1024
+        self.base_patches = (self.base_resolution // 8) // 2
+        self.method = method
+        self.dype = dype if method != "base" else False
+        self.current_timestep = 1.0
+
+    def set_timestep(self, timestep: float):
+        """Set the current timestep for DyPE, normalized to [0, 1] where 1 is pure noise."""
+        self.current_timestep = timestep
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        freqs_dtype = torch.bfloat16 if ids.device.type == "cuda" else torch.float32
+
+        for i in range(n_axes):
+            axis_dim = self.axes_dim[i]
+            axis_pos = pos[..., i]
+
+            common_kwargs = {
+                "dim": axis_dim,
+                "pos": axis_pos,
+                "theta": self.theta,
+                "repeat_interleave_real": True,
+                "use_real": True,
+                "freqs_dtype": freqs_dtype,
+            }
+
+            if i > 0:
+                max_pos = axis_pos.max().item()
+                current_patches = max_pos + 1
+
+                if self.method == "yarn" and current_patches > self.base_patches:
+                    max_pe_len = torch.tensor(
+                        current_patches, dtype=freqs_dtype, device=pos.device
+                    )
+                    cos, sin = get_1d_rotary_pos_embed(
+                        **common_kwargs,
+                        yarn=True,
+                        max_pe_len=max_pe_len,
+                        ori_max_pe_len=self.base_patches,
+                        dype=self.dype,
+                        current_timestep=self.current_timestep,
+                    )
+
+                elif self.method == "ntk" and current_patches > self.base_patches:
+                    base_ntk = (current_patches / self.base_patches) ** (
+                        self.axes_dim[i] / (self.axes_dim[i] - 2)
+                    )
+                    ntk_factor = (
+                        base_ntk ** (2.0 * (self.current_timestep**2.0))
+                        if self.dype
+                        else base_ntk
+                    )
+                    ntk_factor = max(1.0, ntk_factor)
+
+                    cos, sin = get_1d_rotary_pos_embed(
+                        **common_kwargs, ntk_factor=ntk_factor
+                    )
+
+                else:
+                    cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
+            else:
+                cos, sin = get_1d_rotary_pos_embed(**common_kwargs)
+
+            cos_out.append(cos)
+            sin_out.append(sin)
+
+        emb_parts = []
+        for cos, sin in zip(cos_out, sin_out):
+            cos_reshaped = cos.view(*cos.shape[:-1], -1, 2)[..., :1]
+            sin_reshaped = sin.view(*sin.shape[:-1], -1, 2)[..., :1]
+            row1 = torch.cat([cos_reshaped, -sin_reshaped], dim=-1)
+            row2 = torch.cat([sin_reshaped, cos_reshaped], dim=-1)
+            matrix = torch.stack([row1, row2], dim=-2)
+            emb_parts.append(matrix)
+
+        emb = torch.cat(emb_parts, dim=-3)
+        return emb.unsqueeze(1).to(ids.device)
+
+
+def apply_dype_flux(model: ModelPatcher, method: str) -> ModelPatcher:
+    if getattr(model.model, "_dype", None) == method:
+        return model
+
+    m = model.clone()
+    m.model._dype = method
+
+    _pe_embedder = m.model.diffusion_model.pe_embedder
+    _theta, _axes_dim = _pe_embedder.theta, _pe_embedder.axes_dim
+
+    pos_embedder = FluxPosEmbed(_theta, _axes_dim, method, dype=True)
+    m.add_object_patch("diffusion_model.pe_embedder", pos_embedder)
+
+    sigma_max = m.model.model_sampling.sigma_max.item()
+
+    def dype_wrapper_function(model_function, args_dict):
+        timestep_tensor = args_dict.get("timestep")
+        if timestep_tensor is not None and timestep_tensor.numel() > 0:
+            current_sigma = timestep_tensor.flatten()[0].item()
+
+            if sigma_max > 0:
+                normalized_timestep = min(max(current_sigma / sigma_max, 0.0), 1.0)
+                pos_embedder.set_timestep(normalized_timestep)
+
+        input_x, c = args_dict.get("input"), args_dict.get("c", {})
+        return model_function(input_x, args_dict.get("timestep"), **c)
+
+    m.set_model_unet_function_wrapper(dype_wrapper_function)
+
+    return m
+
+
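+# Patch sketch (editor's illustration; `model` is assumed to be a Flux
+# ModelPatcher, e.g. the output of a checkpoint loader):
+#
+#   patched = apply_dype_flux(model, "yarn")
+#
+# The clone shares the original weights; only the pe_embedder object patch and
+# the unet function wrapper are added, so patching is cheap.
+
+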
+class DyPEPatchModelFlux(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="DyPEPatchModelFlux",
+            display_name="DyPE Patch Model (Flux)",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input("model"),
+                io.Combo.Input(
+                    "method",
+                    options=["yarn", "ntk", "base"],
+                    default="yarn",
+                ),
+            ],
+            outputs=[io.Model.Output()],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, model: ModelPatcher, method: str) -> io.NodeOutput:
+        m = apply_dype_flux(model, method)
+        return io.NodeOutput(m)
+
+
+class DyPEExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            DyPEPatchModelFlux,
+        ]
+
+
+async def comfy_entrypoint():
+    return DyPEExtension()
diff --git a/nodes.py b/nodes.py
index 356aa63dfc71..4fb21e59701a 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes():
         "nodes_rope.py",
         "nodes_logic.py",
         "nodes_nop.py",
+        "nodes_dype.py",
     ]
 
     import_failed = []