
Commit c930fbe

committed
cleanup
1 parent 458901b commit c930fbe

File tree: 1 file changed (+3, -37 lines)


more_math/LatentMathNode.py

Lines changed: 3 additions & 37 deletions
@@ -1,20 +1,10 @@
 from inspect import cleandoc
-
+import comfy.nested_tensor
 from comfy_api.latest import io
-
 import torch
-
 from .helper_functions import generate_dim_variables, getIndexTensorAlongDim, comonLazy, parse_expr, eval_tensor_expr_with_tree, make_zero_like
-
 from .MathNodeBase import MathNodeBase

-# try to import NestedTensor type if available
-try:
-    import comfy.nested_tensor as _nested_tensor_module
-    _NESTED_TENSOR_AVAILABLE = True
-except Exception:
-    _nested_tensor_module = None
-    _NESTED_TENSOR_AVAILABLE = False


 class LatentMathNode(MathNodeBase):
@@ -75,64 +65,52 @@ def execute(cls, Latent, a, b=None, c=None, d=None, w=0.0, x=0.0, y=0.0, z=0.0)
         # parse expression once
         tree = parse_expr(Latent)

-        # Helper to evaluate for a single tensor
         def eval_single_tensor(a_t, b_t, c_t, d_t):
-            # support tensors with >=4 dims (e.g. 4D latents [B,C,H,W] or 5D [B,T,C,H,W])
             ndim = a_t.ndim
-            # use negative indexing so that channel/height/width mapping works for 4D and 5D
             batch_dim = 0
             channel_dim = -3
             height_dim = -2
             width_dim = -1
             time_dim = None
             if ndim >= 5:
-                # time/frame dim is the one before channels when present
                 time_dim = -4

             B = getIndexTensorAlongDim(a_t, batch_dim)
             C = getIndexTensorAlongDim(a_t, channel_dim)
             H = getIndexTensorAlongDim(a_t, height_dim)
             W = getIndexTensorAlongDim(a_t, width_dim)

-            # fill scalar/value tensors
             width_val = a_t.shape[width_dim]
             height_val = a_t.shape[height_dim]
             channel_count = a_t.shape[channel_dim]
             batch_count = a_t.shape[batch_dim]
             frame_count = a_t.shape[time_dim] if time_dim is not None else a_t.shape[batch_dim]

             variables = {
-                # core tensors and floats
                 'a': a_t, 'b': b_t, 'c': c_t, 'd': d_t,
                 'w': w, 'x': x, 'y': y, 'z': z,

                 'X': W, 'Y': H,
                 'B': B, 'batch': B,
                 'C': C, 'channel': C,
-                # scalar dims and counts
                 'W': width_val, 'width': width_val,
                 'H': height_val, 'height': height_val,
                 'T': frame_count, 'batch_count': batch_count,
                 'N': channel_count, 'channel_count': channel_count,
             } | generate_dim_variables(a_t)

-            # expose time/frame if present
             if time_dim is not None:
                 F = getIndexTensorAlongDim(a_t, time_dim)
                 variables.update({'frame_idx': F, 'frame': F, 'frame_count': frame_count})

             return eval_tensor_expr_with_tree(tree, variables, a_t.shape)

-        # If input is a NestedTensor (from comfy), evaluate per-subtensor and return NestedTensor result
         if hasattr(a_in, 'is_nested') and getattr(a_in, 'is_nested'):
-            # get underlying lists
             a_list = a_in.unbind()
             sizes = [t.shape[0] for t in a_list]
-            # merge all a subtensors along batch (dim=0)
             merged_a = torch.cat(a_list, dim=0)

             def merge_to_tensor(val, ref):
-                # ref is merged_a
                 if val is None:
                     return make_zero_like(ref)
                 if hasattr(val, 'is_nested') and getattr(val, 'is_nested'):
@@ -141,39 +119,27 @@ def merge_to_tensor(val, ref):
                 if isinstance(val, (list, tuple)):
                     return torch.cat(list(val), dim=0)
                 if torch.is_tensor(val):
-                    # if val already matches merged shape
                     if val.shape == ref.shape:
                         return val
-                    # if val is per-subtensor with same per-subtensor batch, replicate
                     try:
                         if val.shape[0] in sizes and val.shape[1:] == a_list[0].shape[1:]:
-                            # broadcast by concatenating copies
                             return torch.cat([val for _ in a_list], dim=0)
                     except Exception:
                         pass
-                    # if val has batch equal to combined, return as is
                     if val.shape[0] == sum(sizes):
                         return val
-                # fallback
                 return make_zero_like(ref)

             merged_b = merge_to_tensor(b_in, merged_a)
             merged_c = merge_to_tensor(c_in, merged_a)
             merged_d = merge_to_tensor(d_in, merged_a)
-
-            # evaluate once on merged tensors
             merged_result = eval_single_tensor(merged_a, merged_b, merged_c, merged_d)

-            # split back into list
             split_results = list(merged_result.split(sizes, dim=0))
-            if _NESTED_TENSOR_AVAILABLE and _nested_tensor_module is not None:
-                out_samples = _nested_tensor_module.NestedTensor(split_results)
-            else:
-                out_samples = split_results
+            out_samples = comfy.nested_tensor.nested_tensor.NestedTensor(split_results)
+
             return ({"samples": out_samples},)

-        # Non-nested (single tensor) path
-        # ensure b/c/d are set appropriately (zeros_like if None)
         def to_tensor(val, ref):
             if val is None:
                 return make_zero_like(ref)
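For context, the nested-latent path this diff simplifies follows a merge, evaluate once, split pattern: the subtensors of a nested input are concatenated along the batch dimension, the parsed expression is evaluated a single time on the merged tensor, and the result is split back into per-subtensor chunks before being wrapped in comfy's NestedTensor. A minimal standalone sketch of that pattern, using plain torch tensors and a placeholder eval_fn in place of the node's parsed expression (eval_on_nested and eval_fn are hypothetical names used only for illustration):

import torch

def eval_on_nested(subtensors, eval_fn):
    # Remember each subtensor's batch size so the merged result can be split back.
    sizes = [t.shape[0] for t in subtensors]
    # Merge along the batch dimension (dim=0) and evaluate the expression once.
    merged = torch.cat(list(subtensors), dim=0)
    merged_result = eval_fn(merged)
    # Split back into per-subtensor chunks; the node wraps these in NestedTensor.
    return list(merged_result.split(sizes, dim=0))

# Example: scale two latents with different batch sizes in a single pass.
parts = [torch.randn(2, 4, 8, 8), torch.randn(3, 4, 8, 8)]
halved = eval_on_nested(parts, lambda a: a * 0.5)
print([t.shape for t in halved])  # [torch.Size([2, 4, 8, 8]), torch.Size([3, 4, 8, 8])]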
