Skip to content

Commit 9e2e4d4

Browse files
committed
make conditioning math batchable
1 parent db8000d commit 9e2e4d4

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

more_math/ConditioningMathNode.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from unittest import result
12
import torch
23
from .helper_functions import generate_dim_variables, parse_expr, getIndexTensorAlongDim, as_tensor, normalize_to_common_shape, make_zero_like
34
from .Parser.UnifiedMathVisitor import UnifiedMathVisitor
@@ -26,22 +27,23 @@ def define_schema(cls) -> io.Schema:
2627
inputs=[
2728
io.Autogrow.Input(id="V",template=io.Autogrow.TemplatePrefix(io.Conditioning.Input("values"), prefix="V", min=1, max=50)),
2829
io.Autogrow.Input(id="F", template=io.Autogrow.TemplatePrefix(io.Float.Input("float", default=0.0, optional=True, lazy=True, force_input=True), prefix="F", min=1, max=50)),
29-
io.String.Input(id="Expression", default="I0*(1-F0)+I1*F0", tooltip="Expression to apply on tensor part of conditioning"),
30-
io.String.Input(id="Expression_pi", default="I0*(1-F0)+I1*F0", tooltip="Expression to apply on pooled_input part of conditioning"),
30+
io.String.Input(id="Expression",display_name="Tensor expr.", default="I0*(1-F0)+I1*F0", tooltip="Expression to apply on tensor part of conditioning"),
31+
io.String.Input(id="Expression_pi",display_name="pooled output expr.", default="I0*(1-F0)+I1*F0", tooltip="Expression to apply on pooled_input part of conditioning"),
3132
io.Combo.Input(
3233
id="length_mismatch",
3334
options=["tile", "error", "pad"],
3435
default="error",
3536
tooltip="How to handle mismatched image batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
36-
)
37+
),
38+
io.Int.Input(id="batching")
3739
],
3840
outputs=[
39-
io.Conditioning.Output(),
41+
io.Conditioning.Output(is_output_list=True),
4042
],
4143
)
4244

4345
@classmethod
44-
def check_lazy_status(cls, Expression,Expression_pi, V, F, length_mismatch="tile"):
46+
def check_lazy_status(cls, Expression,Expression_pi, V, F,batching, length_mismatch="tile"):
4547

4648
input_stream = InputStream(Expression)
4749
lexer = MathExprLexer(input_stream)
@@ -78,7 +80,7 @@ def check_lazy_status(cls, Expression,Expression_pi, V, F, length_mismatch="tile
7880
return needed1
7981

8082
@classmethod
81-
def execute(cls, V, F, Expression, Expression_pi, length_mismatch="tile"):
83+
def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile"):
8284
# Identify all present conditioning inputs
8385
tensor_keys = [k for k, v in V.items() if v is not None and isinstance(v, list) and len(v) > 0]
8486
if not tensor_keys:
@@ -92,6 +94,7 @@ def execute(cls, V, F, Expression, Expression_pi, length_mismatch="tile"):
9294
conditioning = V[key]
9395
tensors[key] = conditioning[0][0]
9496
# pooled_output is optional in the dict
97+
9598
pooled_outputs[key] = conditioning[0][1].get("pooled_output")
9699

97100
# Normalize main tensors
@@ -169,17 +172,20 @@ def execute(cls, V, F, Expression, Expression_pi, length_mismatch="tile"):
169172
rpooled_raw = visitor_pi.visit(tree_pi)
170173
rpooled = as_tensor(rpooled_raw, a_p.shape)
171174

172-
# Clone result structure
173-
import copy
174-
# Conditioning is often a list of lists/tuples: [[tensor, dict], ...]
175-
# We assume the first element is the main one to update
175+
176+
if rtensor is None:
177+
rtensor = torch.zeros([1])
178+
if rpooled is None:
179+
rpooled = torch.zeros([1])
180+
res = torch.split_copy(rtensor,batching) if batching>0 else [rtensor]
181+
rpld = torch.split_copy(rpooled,batching) if batching>0 else [rpooled]
176182
res_list = []
177-
for i, entry in enumerate(V.get("V0", [])):
178-
if i == 0:
179-
# Update first entry with result
180-
new_dict = copy.deepcopy(entry[1])
181-
new_dict["pooled_output"] = rpooled
182-
res_list.append([rtensor, new_dict])
183-
else:
184-
res_list.append(copy.deepcopy(entry))
183+
for i in range(max(len(res),len(rpld))):
184+
result_tensor = res[i] if i<len(res) else torch.zeros([1])
185+
result_pooled = rpld[i] if i<len(rpld) else torch.zeros([1])
186+
res_list.append(V["V0"])
187+
res_list[i][0][0] = result_tensor
188+
if(len(res_list[i][0])==1):
189+
res_list[i][0].append({"pooled_output",result_pooled})
190+
else: res_list[i][0][1]["pooled_output"]=result_pooled
185191
return (res_list,)

0 commit comments

Comments
 (0)