88from .Parser .MathExprParser import MathExprParser
99import re
1010import copy
11+ from .Stack import MrmthStack
1112
1213class ConditioningMathNode (io .ComfyNode ):
1314 """
@@ -36,15 +37,17 @@ def define_schema(cls) -> io.Schema:
3637 default = "error" ,
3738 tooltip = "How to handle mismatched image batch sizes. tile: repeat shorter inputs; error: raise error on mismatch; pad: treat missing frames as zero."
3839 ),
39- io .Int .Input (id = "batching" )
40+ io .Int .Input (id = "batching" ),
41+ MrmthStack .Input (id = "stack" ,optional = True )
4042 ],
4143 outputs = [
4244 io .Conditioning .Output (is_output_list = True ),
45+ MrmthStack .Output ()
4346 ],
4447 )
4548
4649 @classmethod
47- def check_lazy_status (cls , Expression ,Expression_pi , V , F ,batching , length_mismatch = "tile" ):
50+ def check_lazy_status (cls , Expression ,Expression_pi , V , F ,batching , length_mismatch = "tile" , stack = [] ):
4851
4952 input_stream = InputStream (Expression )
5053 lexer = MathExprLexer (input_stream )
@@ -81,7 +84,7 @@ def check_lazy_status(cls, Expression,Expression_pi, V, F,batching, length_misma
8184 return needed1
8285
8386 @classmethod
84- def execute (cls , V , F , Expression , Expression_pi ,batching , length_mismatch = "tile" ):
87+ def execute (cls , V , F , Expression , Expression_pi ,batching , length_mismatch = "tile" , stack = [] ):
8588 # Identify all present conditioning inputs
8689 tensor_keys = [k for k , v in V .items () if v is not None and isinstance (v , list ) and len (v ) > 0 ]
8790 if not tensor_keys :
@@ -90,7 +93,6 @@ def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile
9093 # Extract tensors and pooled outputs
9194 tensors = {}
9295 pooled_outputs = {}
93- ss = dict ()
9496 for key in tensor_keys :
9597 conditioning = V [key ]
9698 tensors [key ] = conditioning [0 ][0 ]
@@ -152,7 +154,7 @@ def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile
152154
153155 # Execute Expression (Main Tensor)
154156 tree = parse_expr (Expression )
155- visitor = UnifiedMathVisitor (variables , a .shape ,a .device , state_storage = ss )
157+ visitor = UnifiedMathVisitor (variables , a .shape ,a .device , state_storage = stack )
156158 rtensor = visitor .visit (tree )
157159 rtensor = as_tensor (rtensor , a .shape )
158160
@@ -193,7 +195,7 @@ def execute(cls, V, F, Expression, Expression_pi,batching, length_mismatch="tile
193195
194196 # Execute Expression_pi (Pooled Output)
195197 tree_pi = parse_expr (Expression_pi )
196- visitor_pi = UnifiedMathVisitor (variables_pi , a_p .shape ,a_p .device , state_storage = ss )
198+ visitor_pi = UnifiedMathVisitor (variables_pi , a_p .shape ,a_p .device , state_storage = stack )
197199 rpooled_raw = visitor_pi .visit (tree_pi )
198200 rpooled = as_tensor (rpooled_raw , a_p .shape )
199201
0 commit comments