@@ -71,15 +71,16 @@ def inner_iter_args(self):
7171 return self .body .arguments [1 :]
7272
7373
74- def dispatch_index_op_fold_results (
75- ofrs : Sequence [Union [int , Value ]],
74+ def _dispatch_index_op_fold_results (
75+ ofrs : Sequence [Union [Operation , OpView , Value , int ]],
7676) -> Tuple [List [Value ], List [int ]]:
7777 """`mlir::dispatchIndexOpFoldResults`"""
7878 dynamic_vals = []
7979 static_vals = []
8080 for ofr in ofrs :
81- if isinstance (ofr , Value ):
82- dynamic_vals .append (ofr )
81+ if isinstance (ofr , (Operation , OpView , Value )):
82+ val = _get_op_result_or_value (ofr )
83+ dynamic_vals .append (val )
8384 static_vals .append (ShapedType .get_dynamic_size ())
8485 else :
8586 static_vals .append (ofr )
@@ -92,10 +93,10 @@ class ForallOp(ForallOp):
9293
9394 def __init__ (
9495 self ,
95- lower_bounds : Sequence [Union [Value , int ]],
96- upper_bounds : Sequence [Union [Value , int ]],
96+ lower_bounds : Sequence [Union [Operation , OpView , Value , int ]],
97+ upper_bounds : Sequence [Union [Operation , OpView , Value , int ]],
9798 steps : Sequence [Union [Value , int ]],
98- iter_args : Optional [Union [Operation , OpView , Sequence [Value ]]] = None ,
99+ shared_outs : Optional [Union [Operation , OpView , Sequence [Value ]]] = None ,
99100 * ,
100101 mapping = None ,
101102 loc = None ,
@@ -106,18 +107,21 @@ def __init__(
106107 - `lower_bounds` are the values to use as lower bounds of the loop.
107108 - `upper_bounds` are the values to use as upper bounds of the loop.
108109 - `steps` are the values to use as loop steps.
109- - `iter_args ` is a list of additional loop-carried arguments or an operation
110+ - `shared_outs ` is a list of additional loop-carried arguments or an operation
110111 producing them as results.
111112 """
112- if iter_args is None :
113- iter_args = []
114- iter_args = _get_op_results_or_values (iter_args )
115-
116- dynamic_lbs , static_lbs = dispatch_index_op_fold_results (lower_bounds )
117- dynamic_ubs , static_ubs = dispatch_index_op_fold_results (upper_bounds )
118- dynamic_steps , static_steps = dispatch_index_op_fold_results (steps )
119-
120- results = [arg .type for arg in iter_args ]
113+ assert (
114+ len (lower_bounds ) == len (upper_bounds ) == len (steps )
115+ ), "Mismatch in length of lower bounds, upper bounds, and steps"
116+ if shared_outs is None :
117+ shared_outs = []
118+ shared_outs = _get_op_results_or_values (shared_outs )
119+
120+ dynamic_lbs , static_lbs = _dispatch_index_op_fold_results (lower_bounds )
121+ dynamic_ubs , static_ubs = _dispatch_index_op_fold_results (upper_bounds )
122+ dynamic_steps , static_steps = _dispatch_index_op_fold_results (steps )
123+
124+ results = [arg .type for arg in shared_outs ]
121125 super ().__init__ (
122126 results ,
123127 dynamic_lbs ,
@@ -126,7 +130,7 @@ def __init__(
126130 static_lbs ,
127131 static_ubs ,
128132 static_steps ,
129- iter_args ,
133+ shared_outs ,
130134 mapping = mapping ,
131135 loc = loc ,
132136 ip = ip ,
@@ -151,18 +155,17 @@ def induction_variables(self) -> BlockArgumentList:
151155 return self .body .arguments [: self .rank ]
152156
153157 @property
154- def inner_iter_args (self ):
158+ def inner_iter_args (self ) -> BlockArgumentList :
155159 """Returns the loop-carried arguments usable within the loop.
156160
157161 To obtain the loop-carried operands, use `iter_args`.
158162 """
159163 return self .body .arguments [self .rank :]
160164
161- @property
162165 def terminator (self ) -> InParallelOp :
163166 """
164167 Returns the loop terminator if it exists.
165- Otherwise, create a new one.
168+ Otherwise, creates a new one.
166169 """
167170 ops = self .body .operations
168171 with InsertionPoint (self .body ):
0 commit comments