1+ #include " mlir/Analysis/TopologicalSortUtils.h"
12#include " mlir/Dialect/SCF/IR/SCF.h"
23#include " mlir/IR/BuiltinOps.h"
34#include " mlir/IR/ImplicitLocOpBuilder.h"
@@ -188,6 +189,11 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
188189 // captures and thread them in to the regions.
189190 SetVector<Value> captures;
190191 getUsedValuesDefinedAbove (wsOp.getPartitionOpHolder (), captures);
192+
193+ // Find the subgraph that should be cloned into the partition regions. The
194+ // explicit captures are the leaves of the subgraph.
195+ SetVector<Operation *> opsToClone;
196+ SmallVector<Value> explicitCaptures;
191197 for (unsigned i = 0 ; i < captures.size (); ++i) {
192198 Value capture = captures[i];
193199
@@ -198,11 +204,7 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
198204 (defOp->hasTrait <OpTrait::ConstantLike>() ||
199205 isa<RankedTensorType>(capture.getType ()))) {
200206 captures.insert (defOp->operand_begin (), defOp->operand_end ());
201- for (Region *region : wsOp.getPartitionRegions ()) {
202- b.setInsertionPointToStart (®ion->front ());
203- Value copy = b.clone (*capture.getDefiningOp ())->getResult (0 );
204- replaceAllUsesInRegionWith (capture, copy, *region);
205- }
207+ opsToClone.insert (defOp);
206208 continue ;
207209 }
208210
@@ -211,14 +213,30 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
211213 " FIXME: capturing tensor values into warp "
212214 " partitions is not supported" );
213215 }
214- wsOp->insertOperands (wsOp.getNumOperands (), capture);
215- for (Region *region : wsOp.getPartitionRegions ()) {
216+ explicitCaptures.push_back (capture);
217+ }
218+
219+ // Clone the ops into each region in topological order.
220+ opsToClone = topologicalSort (opsToClone);
221+ for (Region *region : wsOp.getPartitionRegions ()) {
222+ b.setInsertionPointToStart (®ion->front ());
223+ IRMapping mapping;
224+ for (Operation *op : opsToClone) {
225+ Value copy = b.clone (*op, mapping)->getResult (0 );
226+ mapping.map (op->getResult (0 ), copy);
227+ replaceAllUsesInRegionWith (op->getResult (0 ), copy, *region);
228+ }
229+ }
230+
231+ // Replace the leaves with explicit captures.
232+ wsOp->insertOperands (wsOp.getNumOperands (), explicitCaptures);
233+ for (Region *region : wsOp.getPartitionRegions ()) {
234+ for (Value capture : explicitCaptures) {
216235 BlockArgument arg =
217236 region->addArgument (capture.getType (), capture.getLoc ());
218237 replaceAllUsesInRegionWith (capture, arg, *region);
219238 }
220239 }
221-
222240 return success ();
223241}
224242
0 commit comments