1+ import re
2+ import torch
3+ from .tools import VariantSupport
4+ from .base_node import NODE_NAME , LatentNode
5+ import comfy .model_management
6+
7+ MAX_RESOLUTION = 32768
8+
9+ # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475
10+ def slerp (val , low , high ):
11+ dims = low .shape
12+
13+ #flatten to batches
14+ low = low .reshape (dims [0 ], - 1 )
15+ high = high .reshape (dims [0 ], - 1 )
16+
17+ low_norm = low / torch .norm (low , dim = 1 , keepdim = True )
18+ high_norm = high / torch .norm (high , dim = 1 , keepdim = True )
19+
20+ # in case we divide by zero
21+ low_norm [low_norm != low_norm ] = 0.0
22+ high_norm [high_norm != high_norm ] = 0.0
23+
24+ omega = torch .acos ((low_norm * high_norm ).sum (1 ))
25+ so = torch .sin (omega )
26+ res = (torch .sin ((1.0 - val )* omega )/ so ).unsqueeze (1 )* low + (torch .sin (val * omega )/ so ).unsqueeze (1 ) * high
27+ return res .reshape (dims )
28+
29+ @VariantSupport ()
30+ class SeedInterpNoise (LatentNode ):
31+ """
32+ Produce a batch of noise tensors whose seeds advance by +1 after
33+ every (interp_steps + 1) frames, with SLERP interpolation between
34+ the two anchor seeds.
35+
36+ Example (start_seed=1, interp_steps=2, frames=7):
37+ seeds: 1 1--2 1--2 2 2--3 2--3 3
38+ ε0 ε1 ε2 ε3 ε4 ε5 ε6
39+ ▲ ▲ ▲ ▲
40+ | |__________| |
41+ |______ segment 1 _____| |___ seg 2 …
42+ """
43+
44+ @classmethod
45+ def INPUT_TYPES (cls ):
46+ return {"required" : {
47+ "source" : (["CPU" , "GPU" ],),
48+ "start_seed" : ("INT" , {"default" : 0 , "min" : 0 ,
49+ "max" : 0xffffffffffffffff }),
50+ "frames" : ("INT" , {"default" : 8 , "min" : 1 ,
51+ "max" : 9999999 }),
52+ "interp_steps" : ("INT" , {"default" : 1 , "min" : 0 ,
53+ "max" : 1024 }),
54+ "width" : ("INT" , {"default" : 512 , "min" : 64 ,
55+ "max" : MAX_RESOLUTION , "step" : 8 }),
56+ "height" : ("INT" , {"default" : 512 , "min" : 64 ,
57+ "max" : MAX_RESOLUTION , "step" : 8 }),
58+ }}
59+
60+ RETURN_TYPES = ("LATENT" ,) # matches NoisyLatentImage
61+ FUNCTION = "build"
62+
63+ def build (self , source , start_seed , frames , interp_steps ,
64+ width , height ):
65+ device = "cpu" if source == "CPU" \
66+ else comfy .model_management .get_torch_device ()
67+ c , h , w = 4 , height // 8 , width // 8
68+
69+ # pre-allocate output tensor
70+ batch = torch .empty ((frames , c , h , w ),
71+ dtype = torch .float32 , device = device )
72+
73+ # helper: epsilon(seed)
74+ def eps (seed ):
75+ g = torch .Generator (device ).manual_seed (seed )
76+ return torch .randn ((c , h , w ), generator = g , device = device )
77+
78+ seg_len = interp_steps + 1 # ε_a + k in-betweens + ε_b
79+ cur_seed = start_seed
80+ frame_idx = 0
81+
82+ while frame_idx < frames :
83+ eps_a = eps (cur_seed )
84+ eps_b = eps (cur_seed + 1 )
85+
86+ # anchor A
87+ batch [frame_idx ] = eps_a
88+ frame_idx += 1
89+ if frame_idx >= frames :
90+ break
91+
92+ # interpolated frames
93+ for k in range (1 , seg_len ):
94+ if frame_idx >= frames :
95+ break
96+ t = k / seg_len
97+ batch [frame_idx ] = slerp (t , eps_a , eps_b )
98+ frame_idx += 1
99+
100+ # anchor B (start of next segment) only if we still need frames
101+ cur_seed += 1
102+
103+ return ({"samples" : batch .cpu ()}, )
104+
105+ LATENT_NODE_CLASS_MAPPINGS = {
106+ "SeedInterpNoise" : SeedInterpNoise ,
107+ }
108+
109+ LATENT_NODE_DISPLAY_NAME_MAPPINGS = {
110+ "SeedInterpNoise" : f"Seed Interp Noise | { NODE_NAME } " ,
111+ }
0 commit comments