Skip to content

Commit 4871c54

Browse files
authored
Merge pull request #4 from deforum-art/akatz-dev
Updated with new latent node, bumped pyproject version
2 parents 18f6c36 + 0c04925 commit 4871c54

File tree

4 files changed

+118
-2
lines changed

4 files changed

+118
-2
lines changed

__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
from .utility_nodes import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
33
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
44
from .lazy_nodes import GENERAL_NODE_CLASS_MAPPINGS, GENERAL_NODE_DISPLAY_NAME_MAPPINGS
5+
from .latent_nodes import LATENT_NODE_CLASS_MAPPINGS, LATENT_NODE_DISPLAY_NAME_MAPPINGS
56

67
NODE_CLASS_MAPPINGS = {}
78
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
89
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
910
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
1011
NODE_CLASS_MAPPINGS.update(GENERAL_NODE_CLASS_MAPPINGS)
12+
NODE_CLASS_MAPPINGS.update(LATENT_NODE_CLASS_MAPPINGS)
1113

1214
NODE_DISPLAY_NAME_MAPPINGS = {}
1315
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
1416
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
1517
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
1618
NODE_DISPLAY_NAME_MAPPINGS.update(GENERAL_NODE_DISPLAY_NAME_MAPPINGS)
19+
NODE_DISPLAY_NAME_MAPPINGS.update(LATENT_NODE_DISPLAY_NAME_MAPPINGS)

base_node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@ class DebugNode(BaseNode):
2323

2424
class ListNode(BaseNode):
2525
CATEGORY = f"{NODE_NAME}/Lists"
26-
26+
27+
class LatentNode(BaseNode):
28+
CATEGORY = f"{NODE_NAME}/Latent"

latent_nodes.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import re
2+
import torch
3+
from .tools import VariantSupport
4+
from .base_node import NODE_NAME, LatentNode
5+
import comfy.model_management
6+
7+
MAX_RESOLUTION=32768
8+
9+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475
10+
def slerp(val, low, high):
11+
dims = low.shape
12+
13+
#flatten to batches
14+
low = low.reshape(dims[0], -1)
15+
high = high.reshape(dims[0], -1)
16+
17+
low_norm = low/torch.norm(low, dim=1, keepdim=True)
18+
high_norm = high/torch.norm(high, dim=1, keepdim=True)
19+
20+
# in case we divide by zero
21+
low_norm[low_norm != low_norm] = 0.0
22+
high_norm[high_norm != high_norm] = 0.0
23+
24+
omega = torch.acos((low_norm*high_norm).sum(1))
25+
so = torch.sin(omega)
26+
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
27+
return res.reshape(dims)
28+
29+
@VariantSupport()
30+
class SeedInterpNoise(LatentNode):
31+
"""
32+
Produce a batch of noise tensors whose seeds advance by +1 after
33+
every (interp_steps + 1) frames, with SLERP interpolation between
34+
the two anchor seeds.
35+
36+
Example (start_seed=1, interp_steps=2, frames=7):
37+
seeds: 1 1--2 1--2 2 2--3 2--3 3
38+
ε0 ε1 ε2 ε3 ε4 ε5 ε6
39+
▲ ▲ ▲ ▲
40+
| |__________| |
41+
|______ segment 1 _____| |___ seg 2 …
42+
"""
43+
44+
@classmethod
45+
def INPUT_TYPES(cls):
46+
return {"required": {
47+
"source": (["CPU", "GPU"],),
48+
"start_seed": ("INT", {"default": 0, "min": 0,
49+
"max": 0xffffffffffffffff}),
50+
"frames": ("INT", {"default": 8, "min": 1,
51+
"max": 9999999}),
52+
"interp_steps": ("INT", {"default": 1, "min": 0,
53+
"max": 1024}),
54+
"width": ("INT", {"default": 512, "min": 64,
55+
"max": MAX_RESOLUTION, "step": 8}),
56+
"height": ("INT", {"default": 512, "min": 64,
57+
"max": MAX_RESOLUTION, "step": 8}),
58+
}}
59+
60+
RETURN_TYPES = ("LATENT",) # matches NoisyLatentImage
61+
FUNCTION = "build"
62+
63+
def build(self, source, start_seed, frames, interp_steps,
64+
width, height):
65+
device = "cpu" if source == "CPU" \
66+
else comfy.model_management.get_torch_device()
67+
c, h, w = 4, height // 8, width // 8
68+
69+
# pre-allocate output tensor
70+
batch = torch.empty((frames, c, h, w),
71+
dtype=torch.float32, device=device)
72+
73+
# helper: epsilon(seed)
74+
def eps(seed):
75+
g = torch.Generator(device).manual_seed(seed)
76+
return torch.randn((c, h, w), generator=g, device=device)
77+
78+
seg_len = interp_steps + 1 # ε_a + k in-betweens + ε_b
79+
cur_seed = start_seed
80+
frame_idx = 0
81+
82+
while frame_idx < frames:
83+
eps_a = eps(cur_seed)
84+
eps_b = eps(cur_seed + 1)
85+
86+
# anchor A
87+
batch[frame_idx] = eps_a
88+
frame_idx += 1
89+
if frame_idx >= frames:
90+
break
91+
92+
# interpolated frames
93+
for k in range(1, seg_len):
94+
if frame_idx >= frames:
95+
break
96+
t = k / seg_len
97+
batch[frame_idx] = slerp(t, eps_a, eps_b)
98+
frame_idx += 1
99+
100+
# anchor B (start of next segment) only if we still need frames
101+
cur_seed += 1
102+
103+
return ({"samples": batch.cpu()}, )
104+
105+
LATENT_NODE_CLASS_MAPPINGS = {
106+
"SeedInterpNoise": SeedInterpNoise,
107+
}
108+
109+
LATENT_NODE_DISPLAY_NAME_MAPPINGS = {
110+
"SeedInterpNoise": f"Seed Interp Noise | {NODE_NAME}",
111+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-deforum"
33
description = "Custom nodes for deforum workflows"
4-
version = "0.2.0"
4+
version = "0.2.1"
55
license = { file = "LICENSE.txt" }
66

77
[project.urls]

0 commit comments

Comments
 (0)