Skip to content

Commit eae6b03

Browse files
committed
Add stack passing
1 parent ac0c2f4 commit eae6b03

15 files changed

+126
-74
lines changed

more_math/AudioMathNode.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .Parser.MathExprLexer import MathExprLexer
1515
from .Parser.MathExprParser import MathExprParser
1616
import re
17+
from .Stack import MrmthStack
1718

1819
class AudioMathNode(io.ComfyNode):
1920
"""
@@ -40,15 +41,17 @@ def define_schema(cls) -> io.Schema:
4041
options=["tile", "error", "pad"],
4142
default="error",
4243
tooltip="How to handle mismatched image batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
43-
)
44+
),
45+
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
4446
],
4547
outputs=[
4648
io.Audio.Output(),
49+
MrmthStack.Output(),
4750
],
4851
)
4952

5053
@classmethod
51-
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile"):
54+
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile",stack=[]):
5255

5356
input_stream = InputStream(Expression)
5457
lexer = MathExprLexer(input_stream)
@@ -80,7 +83,7 @@ def check_lazy_status(cls, Expression, V, F, length_mismatch="tile"):
8083
return needed1
8184

8285
@classmethod
83-
def execute(cls, V, F, Expression, length_mismatch="tile"):
86+
def execute(cls, V, F, Expression, length_mismatch="tile",stack=[]):
8487
# Identify all present audio inputs and their keys
8588
tensor_keys = [k for k, v in V.items() if v is not None and isinstance(v, dict) and "waveform" in v]
8689
if not tensor_keys:
@@ -145,7 +148,7 @@ def execute(cls, V, F, Expression, length_mismatch="tile"):
145148
variables[k] = val if val is not None else 0.0
146149

147150
tree = parse_expr(Expression);
148-
visitor = UnifiedMathVisitor(variables, a_w.shape,a_w.device)
151+
visitor = UnifiedMathVisitor(variables, a_w.shape,a_w.device,state_storage=stack)
149152
result = visitor.visit(tree)
150153
result = as_tensor(result, a_w.shape)
151-
return ({"waveform":result,"sample_rate":sample_rate},)
154+
return ({"waveform":result,"sample_rate":sample_rate},stack)

more_math/ClipMathNode.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .Parser.MathExprLexer import MathExprLexer
55
from .Parser.MathExprParser import MathExprParser
66
import re
7+
from .Stack import MrmthStack
78

89

910
class CLIPMathNode(io.ComfyNode):
@@ -26,17 +27,19 @@ def define_schema(cls) -> io.Schema:
2627
options=["tile", "error", "pad"],
2728
default="error",
2829
tooltip="How to handle mismatched layer counts. For models, this usually defaults to broadcast (zero for missing layers)."
29-
)
30+
),
31+
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
3032
],
3133
outputs=[
3234
io.Clip.Output(),
35+
MrmthStack.Output(),
3336
],
3437
)
3538

3639
tooltip = cleandoc(__doc__)
3740

3841
@classmethod
39-
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile"):
42+
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile",stack=[]):
4043

4144
input_stream = InputStream(Expression)
4245
lexer = MathExprLexer(input_stream)
@@ -68,7 +71,7 @@ def check_lazy_status(cls, Expression, V, F, length_mismatch="tile"):
6871
return needed1
6972

7073
@classmethod
71-
def execute(cls, V, F, Expression, length_mismatch="tile") -> io.NodeOutput:
74+
def execute(cls, V, F, Expression, length_mismatch="tile",stack=[]) -> io.NodeOutput:
7275
# Determine reference CLIP
7376
a = V.get("V0")
7477
if a is None:
@@ -93,9 +96,9 @@ def execute(cls, V, F, Expression, length_mismatch="tile") -> io.NodeOutput:
9396
# The prompt says aliases are supported in check_lazy_status. Variables map in helper handles logic.
9497
aliases = {"a": "V0", "b": "V1", "c": "V2", "d": "V3", "w": "F0", "x": "F1", "y": "F2", "z": "F3"}
9598

96-
patches = calculate_patches_autogrow(Expression, V=patchers_V, F=F, mapping=aliases)
99+
patches = calculate_patches_autogrow(Expression, V=patchers_V, F=F, mapping=aliases,stack=stack)
97100

98101
out_clip = a.clone()
99102
if patches:
100103
out_clip.add_patches(patches, 1.0, 1.0)
101-
return (out_clip,)
104+
return (out_clip,stack)

more_math/ConditioningMathNode.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .Parser.MathExprParser import MathExprParser
99
import re
1010
import copy
11+
from .Stack import MrmthStack
1112

1213
class ConditioningMathNode(io.ComfyNode):
1314
"""
@@ -36,15 +37,17 @@ def define_schema(cls) -> io.Schema:
3637
default="error",
3738
tooltip="How to handle mismatched image batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
3839
),
39-
io.Int.Input(id="batching")
40+
io.Int.Input(id="batching"),
41+
MrmthStack.Input(id="stack",optional=True)
4042
],
4143
outputs=[
4244
io.Conditioning.Output(is_output_list=True),
45+
MrmthStack.Output()
4346
],
4447
)
4548

4649
@classmethod
47-
def check_lazy_status(cls, Expression,Expression_pi, V, F,batching, length_mismatch="tile"):
50+
def check_lazy_status(cls, Expression,Expression_pi, V, F,batching, length_mismatch="tile",stack=[]):
4851

4952
input_stream = InputStream(Expression)
5053
lexer = MathExprLexer(input_stream)
@@ -81,7 +84,7 @@ def check_lazy_status(cls, Expression,Expression_pi, V, F,batching, length_misma
8184
return needed1
8285

8386
@classmethod
84-
def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile"):
87+
def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile",stack=[]):
8588
# Identify all present conditioning inputs
8689
tensor_keys = [k for k, v in V.items() if v is not None and isinstance(v, list) and len(v) > 0]
8790
if not tensor_keys:
@@ -90,7 +93,6 @@ def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile
9093
# Extract tensors and pooled outputs
9194
tensors = {}
9295
pooled_outputs = {}
93-
ss = dict()
9496
for key in tensor_keys:
9597
conditioning = V[key]
9698
tensors[key] = conditioning[0][0]
@@ -152,7 +154,7 @@ def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile
152154

153155
# Execute Expression (Main Tensor)
154156
tree = parse_expr(Expression)
155-
visitor = UnifiedMathVisitor(variables, a.shape,a.device, state_storage=ss)
157+
visitor = UnifiedMathVisitor(variables, a.shape,a.device, state_storage=stack)
156158
rtensor = visitor.visit(tree)
157159
rtensor = as_tensor(rtensor, a.shape)
158160

@@ -193,7 +195,7 @@ def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile
193195

194196
# Execute Expression_pi (Pooled Output)
195197
tree_pi = parse_expr(Expression_pi)
196-
visitor_pi = UnifiedMathVisitor(variables_pi, a_p.shape,a_p.device, state_storage=ss)
198+
visitor_pi = UnifiedMathVisitor(variables_pi, a_p.shape,a_p.device, state_storage=stack)
197199
rpooled_raw = visitor_pi.visit(tree_pi)
198200
rpooled = as_tensor(rpooled_raw, a_p.shape)
199201

more_math/FloatMathNode.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .Parser.MathExprLexer import MathExprLexer
1010
from .Parser.MathExprParser import MathExprParser
1111
import re
12+
from .Stack import MrmthStack
1213

1314

1415
class FloatMathNode(io.ComfyNode):
@@ -29,16 +30,18 @@ def define_schema(cls) -> io.Schema:
2930
inputs=[
3031
io.Autogrow.Input(id="V",template=io.Autogrow.TemplatePrefix(io.Float.Input("values"), prefix="V", min=1, max=50)),
3132
io.String.Input(id="FloatFunc", default="a*(1-w)+b*w", tooltip="Expression to use on inputs"),
33+
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
3234
],
3335
outputs=[
3436
io.Float.Output(),
37+
MrmthStack.Output(),
3538
],
3639
)
3740

3841
tooltip = cleandoc(__doc__)
3942

4043
@classmethod
41-
def check_lazy_status(cls, FloatFunc, V):
44+
def check_lazy_status(cls, FloatFunc, V,stack=[]):
4245
input_stream = InputStream(FloatFunc)
4346
lexer = MathExprLexer(input_stream)
4447
stream = CommonTokenStream(lexer)
@@ -69,7 +72,7 @@ def check_lazy_status(cls, FloatFunc, V):
6972
return needed1
7073

7174
@classmethod
72-
def execute(cls, FloatFunc, V):
75+
def execute(cls, FloatFunc, V,stack=[]):
7376

7477
variables = {}
7578
# Populate aliases
@@ -95,9 +98,9 @@ def execute(cls, FloatFunc, V):
9598
tree = parse_expr(FloatFunc);
9699
# scalar execution
97100
# UnifiedMathVisitor expects variables and a shape. Shape [1] for scalar?
98-
visitor = UnifiedMathVisitor(variables, [1])
101+
visitor = UnifiedMathVisitor(variables, [1],state_storage=stack)
99102
result = visitor.visit(tree)
100103
# Result might be float or tensor(scalar)
101104
if torch.is_tensor(result):
102105
result = result[0].item()
103-
return (float(result),)
106+
return (float(result),stack)

more_math/GuiderMathNode.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from numpy import stack
12
import torch
23
import re
34
from antlr4 import InputStream, CommonTokenStream
@@ -19,6 +20,7 @@
1920
import comfy.utils
2021
import comfy.hooks
2122
import comfy.samplers
23+
from .Stack import MrmthStack
2224

2325

2426
class GuiderMathNode(io.ComfyNode):
@@ -37,14 +39,16 @@ def define_schema(cls) -> io.Schema:
3739
io.Autogrow.Input(id="F", template=io.Autogrow.TemplatePrefix(io.Float.Input("float", default=0.0, optional=True, lazy=True, force_input=True), prefix="F", min=1, max=50)),
3840
io.String.Input(id="Expression", default="G0*(1-F0)+G1*F0", tooltip="Expression to apply on input guiders. Aliases: a=G0, b=G1, c=G2, d=G3, w=F0, x=F1, y=F2, z=F3. Context: steps, current_step"),
3941
io.String.Input(id="Expression1", default="G0*(1-F0)+G1*F0", tooltip="Expression to apply after generation finishes."),
42+
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
4043
],
4144
outputs=[
4245
io.Guider.Output(),
46+
MrmthStack.Output()
4347
],
4448
)
4549

4650
@classmethod
47-
def check_lazy_status(cls, Expression,Expression1, V, F):
51+
def check_lazy_status(cls, Expression,Expression1, V, F,stack=[]):
4852
input_stream = InputStream(Expression)
4953
input_stream1 = InputStream(Expression1)
5054
lexer = MathExprLexer(input_stream)
@@ -79,12 +83,12 @@ def check_lazy_status(cls, Expression,Expression1, V, F):
7983
return needed1
8084

8185
@classmethod
82-
def execute(cls, V, F, Expression,Expression1):
83-
return (MathGuider(V, F, Expression,Expression1),)
86+
def execute(cls, V, F, Expression,Expression1,stack=[]):
87+
return (MathGuider(V, F, Expression,Expression1),stack)
8488

8589

8690
class MathGuider:
87-
def __init__(self, V, F, expression,expression1):
91+
def __init__(self, V, F, expression,expression1,stack=[]):
8892
self.V = V
8993
self.F = F
9094
self.expression = expression
@@ -94,7 +98,7 @@ def __init__(self, V, F, expression,expression1):
9498
self.sigmas = None
9599
self.current_step = 0
96100
self.steps = 0
97-
self.stck = {}
101+
self.stck = stack
98102

99103
@property
100104
def model_patcher(self):
@@ -167,7 +171,7 @@ def setVars(self, x, sigma, seed, g_results):
167171
"c": g_results.get("V2", make_zero_like(eval_samples)),
168172
"d": g_results.get("V3", make_zero_like(eval_samples)),
169173
})
170-
174+
171175
v_stacked, v_cnt = get_v_variable(g_results)
172176
if v_stacked is not None:
173177
variables["V"] = v_stacked

more_math/ImageMathNode.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .Parser.MathExprLexer import MathExprLexer
66
from .Parser.MathExprParser import MathExprParser
77
import re
8+
from .Stack import MrmthStack
89

910
class ImageMathNode(io.ComfyNode):
1011
"""
@@ -31,15 +32,17 @@ def define_schema(cls) -> io.Schema:
3132
options=["tile", "error", "pad"],
3233
default="error",
3334
tooltip="How to handle mismatched image batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
34-
)
35+
),
36+
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
3537
],
3638
outputs=[
3739
io.Image.Output(),
40+
MrmthStack.Output(),
3841
],
3942
)
4043

4144
@classmethod
42-
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile"):
45+
def check_lazy_status(cls, Expression, V, F, length_mismatch="tile",stack=[]):
4346

4447
input_stream = InputStream(Expression)
4548
lexer = MathExprLexer(input_stream)
@@ -71,7 +74,7 @@ def check_lazy_status(cls, Expression, V, F, length_mismatch="tile"):
7174
return needed1
7275

7376
@classmethod
74-
def execute(cls, V, F, Expression, length_mismatch="error"):
77+
def execute(cls, V, F, Expression, length_mismatch="error",stack=[]):
7578
# I and F are Autogrow.Type which is dict[str, Any]
7679

7780
# Identify all present tensors and their keys
@@ -144,7 +147,7 @@ def execute(cls, V, F, Expression, length_mismatch="error"):
144147
variables[k] = val if val is not None else 0.0
145148

146149
tree = parse_expr(Expression);
147-
visitor = UnifiedMathVisitor(variables, ae.shape,ae.device)
150+
visitor = UnifiedMathVisitor(variables, ae.shape,ae.device,state_storage=stack)
148151
result = visitor.visit(tree)
149152
result = as_tensor(result, ae.shape)
150-
return (result,)
153+
return (result,stack)

more_math/LatentMathNode.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .Parser.MathExprParser import MathExprParser
1818
import re
1919
from comfy.nested_tensor import NestedTensor
20+
from .Stack import MrmthStack
2021

2122
class LatentMathNode(io.ComfyNode):
2223
"""
@@ -43,17 +44,20 @@ def define_schema(cls) -> io.Schema:
4344
default="error",
4445
tooltip="How to handle mismatched latent batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
4546
),
46-
io.Int.Input(id="batching")
47+
io.Int.Input(id="batching"),
48+
MrmthStack.Input(id="stack", tooltip="Access stack between nodes",optional=True)
4749
],
4850
outputs=[
4951
io.Latent.Output(is_output_list=True),
52+
MrmthStack.Output(),
53+
5054
],
5155
)
5256

5357
tooltip = cleandoc(__doc__)
5458

5559
@classmethod
56-
def check_lazy_status(cls, Expression, V, F,batching, length_mismatch="tile"):
60+
def check_lazy_status(cls, Expression, V, F,batching, length_mismatch="tile",stack=[]):
5761

5862
input_stream = InputStream(Expression)
5963
lexer = MathExprLexer(input_stream)
@@ -85,7 +89,7 @@ def check_lazy_status(cls, Expression, V, F,batching, length_mismatch="tile"):
8589
return needed1
8690

8791
@classmethod
88-
def execute(cls, V, F, Expression,batching, length_mismatch="tile") -> io.NodeOutput:
92+
def execute(cls, V, F, Expression,batching, length_mismatch="tile",stack=[]) -> io.NodeOutput:
8993
# Determine reference latent
9094
ref_latent = None
9195
for lat in V.values():
@@ -198,7 +202,7 @@ def execute(cls, V, F, Expression,batching, length_mismatch="tile") -> io.NodeOu
198202
for k, v in F.items():
199203
variables[k] = v if v is not None else 0.0
200204

201-
visitor = UnifiedMathVisitor(variables, ae.shape,ae.device)
205+
visitor = UnifiedMathVisitor(variables, ae.shape,ae.device,state_storage=stack)
202206
result_t = as_tensor(visitor.visit(tree), ae.shape)
203207

204208
result_latent = ref_latent.copy()
@@ -221,7 +225,7 @@ def execute(cls, V, F, Expression,batching, length_mismatch="tile") -> io.NodeOu
221225
else:
222226
rl["samples"] = result_t
223227
results1.append(rl)
224-
return (results1,)
228+
return (results1,stack)
225229
rl = result_latent.copy()
226230
rl["samples"] = result_t
227-
return ([rl],)
231+
return ([rl],stack)

0 commit comments

Comments
 (0)