11from inspect import cleandoc
2-
2+ import comfy . nested_tensor
33from comfy_api .latest import io
4-
54import torch
6-
75from .helper_functions import generate_dim_variables , getIndexTensorAlongDim , comonLazy , parse_expr , eval_tensor_expr_with_tree , make_zero_like
8-
96from .MathNodeBase import MathNodeBase
107
11- # try to import NestedTensor type if available
12- try :
13- import comfy .nested_tensor as _nested_tensor_module
14- _NESTED_TENSOR_AVAILABLE = True
15- except Exception :
16- _nested_tensor_module = None
17- _NESTED_TENSOR_AVAILABLE = False
188
199
2010class LatentMathNode (MathNodeBase ):
@@ -75,64 +65,52 @@ def execute(cls, Latent, a, b=None, c=None, d=None, w=0.0, x=0.0, y=0.0, z=0.0)
7565 # parse expression once
7666 tree = parse_expr (Latent )
7767
78- # Helper to evaluate for a single tensor
7968 def eval_single_tensor (a_t , b_t , c_t , d_t ):
80- # support tensors with >=4 dims (e.g. 4D latents [B,C,H,W] or 5D [B,T,C,H,W])
8169 ndim = a_t .ndim
82- # use negative indexing so that channel/height/width mapping works for 4D and 5D
8370 batch_dim = 0
8471 channel_dim = - 3
8572 height_dim = - 2
8673 width_dim = - 1
8774 time_dim = None
8875 if ndim >= 5 :
89- # time/frame dim is the one before channels when present
9076 time_dim = - 4
9177
9278 B = getIndexTensorAlongDim (a_t , batch_dim )
9379 C = getIndexTensorAlongDim (a_t , channel_dim )
9480 H = getIndexTensorAlongDim (a_t , height_dim )
9581 W = getIndexTensorAlongDim (a_t , width_dim )
9682
97- # fill scalar/value tensors
9883 width_val = a_t .shape [width_dim ]
9984 height_val = a_t .shape [height_dim ]
10085 channel_count = a_t .shape [channel_dim ]
10186 batch_count = a_t .shape [batch_dim ]
10287 frame_count = a_t .shape [time_dim ] if time_dim is not None else a_t .shape [batch_dim ]
10388
10489 variables = {
105- # core tensors and floats
10690 'a' : a_t , 'b' : b_t , 'c' : c_t , 'd' : d_t ,
10791 'w' : w , 'x' : x , 'y' : y , 'z' : z ,
10892
10993 'X' : W , 'Y' : H ,
11094 'B' : B , 'batch' : B ,
11195 'C' : C , 'channel' : C ,
112- # scalar dims and counts
11396 'W' : width_val , 'width' : width_val ,
11497 'H' : height_val , 'height' : height_val ,
11598 'T' : frame_count , 'batch_count' : batch_count ,
11699 'N' : channel_count , 'channel_count' : channel_count ,
117100 } | generate_dim_variables (a_t )
118101
119- # expose time/frame if present
120102 if time_dim is not None :
121103 F = getIndexTensorAlongDim (a_t , time_dim )
122104 variables .update ({'frame_idx' : F , 'frame' : F , 'frame_count' : frame_count })
123105
124106 return eval_tensor_expr_with_tree (tree , variables , a_t .shape )
125107
126- # If input is a NestedTensor (from comfy), evaluate per-subtensor and return NestedTensor result
127108 if hasattr (a_in , 'is_nested' ) and getattr (a_in , 'is_nested' ):
128- # get underlying lists
129109 a_list = a_in .unbind ()
130110 sizes = [t .shape [0 ] for t in a_list ]
131- # merge all a subtensors along batch (dim=0)
132111 merged_a = torch .cat (a_list , dim = 0 )
133112
134113 def merge_to_tensor (val , ref ):
135- # ref is merged_a
136114 if val is None :
137115 return make_zero_like (ref )
138116 if hasattr (val , 'is_nested' ) and getattr (val , 'is_nested' ):
@@ -141,39 +119,27 @@ def merge_to_tensor(val, ref):
141119 if isinstance (val , (list , tuple )):
142120 return torch .cat (list (val ), dim = 0 )
143121 if torch .is_tensor (val ):
144- # if val already matches merged shape
145122 if val .shape == ref .shape :
146123 return val
147- # if val is per-subtensor with same per-subtensor batch, replicate
148124 try :
149125 if val .shape [0 ] in sizes and val .shape [1 :] == a_list [0 ].shape [1 :]:
150- # broadcast by concatenating copies
151126 return torch .cat ([val for _ in a_list ], dim = 0 )
152127 except Exception :
153128 pass
154- # if val has batch equal to combined, return as is
155129 if val .shape [0 ] == sum (sizes ):
156130 return val
157- # fallback
158131 return make_zero_like (ref )
159132
160133 merged_b = merge_to_tensor (b_in , merged_a )
161134 merged_c = merge_to_tensor (c_in , merged_a )
162135 merged_d = merge_to_tensor (d_in , merged_a )
163-
164- # evaluate once on merged tensors
165136 merged_result = eval_single_tensor (merged_a , merged_b , merged_c , merged_d )
166137
167- # split back into list
168138 split_results = list (merged_result .split (sizes , dim = 0 ))
169- if _NESTED_TENSOR_AVAILABLE and _nested_tensor_module is not None :
170- out_samples = _nested_tensor_module .NestedTensor (split_results )
171- else :
172- out_samples = split_results
139+ out_samples = comfy .nested_tensor .nested_tensor .NestedTensor (split_results )
140+
173141 return ({"samples" : out_samples },)
174142
175- # Non-nested (single tensor) path
176- # ensure b/c/d are set appropriately (zeros_like if None)
177143 def to_tensor (val , ref ):
178144 if val is None :
179145 return make_zero_like (ref )
0 commit comments