Commit 1d6b7dd

[Warp Specialization] Fix partition loops capture order (#6757)
Ops need to be rematerialized in topological order.
1 parent 9c480c9 commit 1d6b7dd
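
Before this change, each rematerializable defining op was cloned into the partition regions one at a time, at the top of each region, in the order the capture walk discovered it. The walk proceeds from uses toward definitions, so a clone could end up placed before the clone of its operand and be left referring to the original value defined above the ttg.warp_specialize op. The fix collects the ops first, topologically sorts them, and clones them through an IRMapping so each clone's operands resolve to the in-region copies that precede it.

A minimal sketch of the fixed cloning scheme, assuming MLIR's C++ API; the helper name cloneIntoRegions and the regions parameter are illustrative, not part of the patch:

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Clone `opsToClone` into each region so that every op is emitted after the
// ops defining its operands, with operands remapped to the in-region clones.
static void cloneIntoRegions(OpBuilder &b, SetVector<Operation *> opsToClone,
                             ArrayRef<Region *> regions) {
  // Discovery order (uses toward defs) gives no def-before-use guarantee;
  // restore one before cloning.
  opsToClone = topologicalSort(opsToClone);
  for (Region *region : regions) {
    b.setInsertionPointToStart(&region->front());
    IRMapping mapping; // redirects operands of later clones to earlier clones
    for (Operation *op : opsToClone) {
      Value copy = b.clone(*op, mapping)->getResult(0);
      mapping.map(op->getResult(0), copy);
      replaceAllUsesInRegionWith(op->getResult(0), copy, *region);
    }
  }
}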

2 files changed: +47 -9 lines changed
lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 26 additions & 8 deletions
@@ -1,3 +1,4 @@
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -188,6 +189,11 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
   // captures and thread them in to the regions.
   SetVector<Value> captures;
   getUsedValuesDefinedAbove(wsOp.getPartitionOpHolder(), captures);
+
+  // Find the subgraph that should be cloned into the partition regions. The
+  // explicit captures are the leaves of the subgraph.
+  SetVector<Operation *> opsToClone;
+  SmallVector<Value> explicitCaptures;
   for (unsigned i = 0; i < captures.size(); ++i) {
     Value capture = captures[i];
 
@@ -198,11 +204,7 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
         (defOp->hasTrait<OpTrait::ConstantLike>() ||
          isa<RankedTensorType>(capture.getType()))) {
       captures.insert(defOp->operand_begin(), defOp->operand_end());
-      for (Region *region : wsOp.getPartitionRegions()) {
-        b.setInsertionPointToStart(&region->front());
-        Value copy = b.clone(*capture.getDefiningOp())->getResult(0);
-        replaceAllUsesInRegionWith(capture, copy, *region);
-      }
+      opsToClone.insert(defOp);
       continue;
     }
 
@@ -211,14 +213,30 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
                                "FIXME: capturing tensor values into warp "
                                "partitions is not supported");
     }
-    wsOp->insertOperands(wsOp.getNumOperands(), capture);
-    for (Region *region : wsOp.getPartitionRegions()) {
+    explicitCaptures.push_back(capture);
+  }
+
+  // Clone the ops into each region in topological order.
+  opsToClone = topologicalSort(opsToClone);
+  for (Region *region : wsOp.getPartitionRegions()) {
+    b.setInsertionPointToStart(&region->front());
+    IRMapping mapping;
+    for (Operation *op : opsToClone) {
+      Value copy = b.clone(*op, mapping)->getResult(0);
+      mapping.map(op->getResult(0), copy);
+      replaceAllUsesInRegionWith(op->getResult(0), copy, *region);
+    }
+  }
+
+  // Replace the leaves with explicit captures.
+  wsOp->insertOperands(wsOp.getNumOperands(), explicitCaptures);
+  for (Region *region : wsOp.getPartitionRegions()) {
+    for (Value capture : explicitCaptures) {
       BlockArgument arg =
           region->addArgument(capture.getType(), capture.getLoc());
       replaceAllUsesInRegionWith(capture, arg, *region);
     }
   }
-
   return success();
 }
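
For intuition about the sorting step, here is a self-contained sketch (the setup is invented for illustration; topologicalSort and the other MLIR APIs are real) showing a use-first discovery order being restored to def-before-use order:

#include <cassert>
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  ctx.loadDialect<arith::ArithDialect>();
  OpBuilder b(&ctx);
  auto module = ModuleOp::create(b.getUnknownLoc());
  b.setInsertionPointToStart(module.getBody());

  // %c = arith.constant 1 : i32
  // %e = arith.extsi %c : i32 to i64
  auto cst = b.create<arith::ConstantIntOp>(b.getUnknownLoc(), 1, 32);
  auto ext = b.create<arith::ExtSIOp>(b.getUnknownLoc(), b.getI64Type(), cst);

  // Simulate discovery from uses toward definitions: the user is found first.
  SetVector<Operation *> discovered;
  discovered.insert(ext);
  discovered.insert(cst);

  // topologicalSort restores def-before-use order: {constant, extsi}.
  SetVector<Operation *> sorted = topologicalSort(discovered);
  assert(sorted[0] == cst.getOperation() && sorted[1] == ext.getOperation());
  return 0;
}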

test/TritonGPU/partition-loops.mlir

Lines changed: 21 additions & 1 deletion
@@ -206,8 +206,8 @@ tt.func @trivial_tensor_captures(%arg0: f16, %lb: i32, %ub: i32, %step: i32) {
   // CHECK: ttg.warp_specialize(%arg1, %arg2, %arg3, %arg0)
   scf.for %i = %lb to %ub step %step : i32 {
     // CHECK: partition0(%arg4: i32, %arg5: i32, %arg6: i32, %arg7: f16) num_warps(4)
-    // CHECK-NEXT: [[SPLAT:%.*]] = tt.splat %arg7 : f16 -> tensor<32xf16>
     // CHECK-NEXT: [[RANGE:%.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
+    // CHECK-NEXT: [[SPLAT:%.*]] = tt.splat %arg7 : f16 -> tensor<32xf16>
     // CHECK-NEXT: scf.for
     // CHECK-NEXT: "use"([[RANGE]], [[SPLAT]])
     "use"(%0, %1) {ttg.partition = 1} : (tensor<256xi32>, tensor<32xf16>) -> ()
@@ -238,4 +238,24 @@ tt.func @dce_before_warp_allocation(%lb: i32, %ub: i32, %step: i32) {
   tt.return
 }
 
+// CHECK-LABEL: @capture_order
+tt.func public @capture_order(%arg0: i32) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked>
+  %1 = arith.extsi %0 : tensor<4xi32, #blocked> to tensor<4xi64, #blocked>
+  // CHECK: ttg.warp_specialize
+  // CHECK: partition0
+  // CHECK: [[VALUE:%.*]] = tt.make_range
+  // CHECK-NEXT: [[EXT:%.*]] = arith.extsi [[VALUE]]
+  // CHECK-NEXT: scf.for
+  scf.for %arg1 = %c0_i32 to %arg0 step %c1_i32 : i32 {
+    // CHECK-NEXT: "use"([[VALUE]])
+    "use"(%0) : (tensor<4xi32, #blocked>) -> ()
+    // CHECK-NEXT: "use"([[EXT]])
+    "use"(%1) : (tensor<4xi64, #blocked>) -> ()
+  } {ttg.partition.stages = [1 : i32, 0 : i32]}
+  tt.return
+}
+
 }
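
The new @capture_order test pins this down: %1 (arith.extsi) is defined in terms of %0 (tt.make_range), and both are used inside the loop. The CHECK lines insist that the rematerialized make_range appears before the extsi that consumes it inside partition0, an ordering the old discovery-order cloning did not guarantee.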
