
Commit 5a6cf81

Fix functional problem and add lit test
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent: 041e2da

File tree: 2 files changed (+128, -7 lines)

test/TritonIntelGPU/coalesce.mlir

Lines changed: 69 additions & 5 deletions
@@ -134,6 +134,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
 
 // -----
 
+// COM: Test coalescing on blocked pointers: coalescable load using a block pointer in an SCF for loop.
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -225,6 +227,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // -----
 
+// COM: Test coalescing on blocked pointers: loop results used by another loop.
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
 #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
 #dot2 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
@@ -256,19 +260,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %20 = arith.extsi %arg11 : i32 to i64
     // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : <tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>
     %21 = tt.make_tensor_ptr %19, [%c64_i64, %13], [%c1_i64, %20], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf8E5M2, #dot2>>
-    // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args(%arg6 = %cst, %arg7 = [[PTR1]]) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>)
+    // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = %cst, [[ARG2:%.*]] = [[PTR1]]) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>)
     %33:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %cst_1, %arg23 = %21) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr<tensor<64x32xf8E5M2, #dot2>>) : i32 {
-      // CHECK: [[LOAD:%.*]] = tt.load %arg7 : !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>
+      // CHECK: [[LOAD:%.*]] = tt.load [[ARG2]] : !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>
       // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-      // CHECK-NEXT: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, #blocked>>
+      // CHECK-NEXT: scf.yield [[ARG1]], [[ARG2]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, #blocked>>
       %load = tt.load %arg23 : !tt.ptr<tensor<64x32xf8E5M2, #dot2>>
       scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr<tensor<64x32xf8E5M2, #dot2>>
     }
-    // CHECK: scf.for {{.*}} iter_args(%arg6 = [[RES]]#0, %arg7 = [[RES]]#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>)
+    // CHECK: scf.for {{.*}} iter_args([[ARG1:%.*]] = [[RES]]#0, [[ARG2:%.*]] = [[RES]]#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>)
     %34:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg22 = %33#0, %arg23 = %33#1) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr<tensor<64x32xf8E5M2, #dot2>>) : i32 {
-      // CHECK: scf.yield %arg6, %arg7 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>
+      // CHECK: scf.yield [[ARG1]], [[ARG2]] : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>
       scf.yield %arg22, %arg23 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #dpas}>>, !tt.ptr<tensor<64x32xf8E5M2, #dot2>>
     }
     tt.return
   }
 }
+
+// -----
+
+// COM: Test coalescing on blocked pointers: loop with 2 output blocked pointers.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK: [[BLOCKED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+  // CHECK: @test_block_ptrs
+  tt.func public @test_block_ptrs(%arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32, %arg11: i32 {tt.divisibility = 16 : i32}, %arg14: i32, %arg19: i32, %arg20: i32) {
+    %c32_i32 = arith.constant 32 : i32
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
+    %c64_i32 = arith.constant 64 : i32
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = tt.get_program_id y : i32
+    %2 = arith.divsi %1, %arg19 : i32
+    %3 = arith.remsi %1, %arg19 : i32
+    %4 = arith.extsi %2 : i32 to i64
+    %5 = arith.extsi %arg6 : i32 to i64
+    %6 = arith.muli %4, %5 : i64
+    %7 = arith.extsi %3 : i32 to i64
+    %8 = arith.extsi %arg7 : i32 to i64
+    %9 = arith.muli %7, %8 : i64
+    %10 = arith.addi %6, %9 : i64
+    %11 = tt.addptr %arg0, %10 : !tt.ptr<f8E5M2>, i64
+    %12 = arith.muli %0, %c64_i32 : i32
+    %13 = arith.extsi %arg20 : i32 to i64
+    %14 = arith.extsi %arg8 : i32 to i64
+    %15 = tt.make_tensor_ptr %11, [%13, %c64_i64], [%14, %c1_i64], [%12, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+    %16 = tt.addptr %arg2, %10 : !tt.ptr<f8E5M2>, i64
+    %17 = arith.extsi %arg14 : i32 to i64
+    // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    %18 = tt.make_tensor_ptr %16, [%13, %c64_i64], [%c1_i64, %17], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    %19 = tt.addptr %arg1, %10 : !tt.ptr<f8E5M2>, i64
+    %20 = arith.extsi %arg11 : i32 to i64
+    // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}} : <tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>
+    %21 = tt.make_tensor_ptr %19, [%c64_i64, %13], [%c1_i64, %20], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    %32 = tt.load %15 : !tt.ptr<tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>>
+    // CHECK: scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR2]], [[ARG2:%.*]] = [[PTR1]]) -> (!tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
+    %35:2 = scf.for %arg21 = %c0_i32 to %12 step %c32_i32 iter_args(%arg25 = %21, %arg26 = %18) -> (!tt.ptr<tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>) : i32 {
+      // CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] : !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>
+      // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]> -> tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %58 = tt.load %arg25 : !tt.ptr<tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %59 = tt.fp_to_fp %32 : tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %60 = tt.fp_to_fp %58 : tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %61 = tt.dot %59, %60, %cst_2, inputPrecision = tf32 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+      // CHECK-DAG: [[ADVANCE1:%.*]] = tt.advance [[ARG1]], {{.*}} : <tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>
+      // CHECK-DAG: [[ADVANCE2:%.*]] = tt.advance [[ARG2]], {{.*}} : <tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      // CHECK-NEXT: scf.yield [[ADVANCE1]], [[ADVANCE2]] : !tt.ptr<tensor<64x32xf8E5M2, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %84 = tt.advance %arg26, [%c32_i32, %c0_i32] : <tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      %85 = tt.advance %arg25, [%c0_i32, %c32_i32] : <tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+      scf.yield %85, %84 : !tt.ptr<tensor<64x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<32x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+    }
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 59 additions & 2 deletions
@@ -30,6 +30,7 @@ namespace {
 
 struct CoalescePass
     : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
+private:
   void
   setCoalescedEncoding(tt::intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
                        Operation *op, int numWarps, int threadsPerWarp,
@@ -180,7 +181,29 @@
     if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
       // Modify and propagate the result of the enclosing loop.
       auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-      changeAndPropagateLayout(forOp, layout, rewriter);
+
+      rewriter.modifyOpInPlace(forOp, [&]() {
+        for (auto [opType, res] :
+             llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
+          if (opType == res.getType())
+            continue;
+
+          assert(tt::isTensorPointerType(res.getType()) &&
+                 tt::isTensorPointerType(opType) &&
+                 "Expecting blocked pointers");
+          assert(cast<RankedTensorType>(
+                     cast<tt::PointerType>(opType).getPointeeType())
+                     .getEncoding() == layout &&
+                 "Unexpected layout");
+
+          auto resType = cast<tt::PointerType>(res.getType());
+          auto tensorType = cast<RankedTensorType>(resType.getPointeeType());
+          res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                           resType.getAddressSpace()));
+        }
+      });
+
+      propagateLayout(forOp, layout, rewriter);
       continue;
     }
 
@@ -204,7 +227,29 @@
     if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
       // Modify and propagate the result of the enclosing loop.
       auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-      changeAndPropagateLayout(forOp, layout, rewriter);
+
+      rewriter.modifyOpInPlace(forOp, [&]() {
+        for (auto [opType, res] :
+             llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
+          if (opType == res.getType())
+            continue;
+
+          assert(tt::isTensorPointerType(res.getType()) &&
+                 tt::isTensorPointerType(opType) &&
+                 "Expecting blocked pointers");
+          assert(cast<RankedTensorType>(
+                     cast<tt::PointerType>(opType).getPointeeType())
+                     .getEncoding() == layout &&
+                 "Unexpected layout");
+
+          auto resType = cast<tt::PointerType>(res.getType());
+          auto tensorType = cast<RankedTensorType>(resType.getPointeeType());
+          res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                           resType.getAddressSpace()));
+        }
+      });
+
+      propagateLayout(forOp, layout, rewriter);
       continue;
     }
 
@@ -248,6 +293,10 @@
     if (!tt::isTensorPointerType(res.getType()))
       continue;
 
+    // Problem: if the operation is a for loop, we cannot modify the layout of
+    // all the tensor ptr results; we need to modify only the one used by the
+    // yield operation.
+
     auto ptrType = cast<tt::PointerType>(res.getType());
     auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
     res.setType(tt::PointerType::get(getNewType(tensorType, layout),
@@ -319,6 +368,7 @@
     assert(succeeded(verify(newOp)) && "Operation verification failed");
   }
 
+public:
   void runOnOperation() override {
     // Run axis info analysis
     ModuleOp moduleOp = getOperation();
@@ -340,6 +390,13 @@
     if (!refTensorType || !refTensorType.getEncoding())
       return;
 
+    // static int n = 0;
+    // if (tt::isTensorPointerType(ptr.getType()))
+    //   n++;
+
+    // if (n != 2)
+    //   return;
+
     int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp);
     int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);
     setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
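
The in-place retyping added in the two yield-handling hunks above is identical at both call sites. Below is a minimal sketch, not part of the commit, of how that logic could be hoisted into a shared helper. It assumes the surrounding code of Coalesce.cpp (the tt namespace alias, getNewType, propagateLayout); the helper name updateLoopResultTypes and the RewriterBase/Attribute parameter types are assumptions inferred from the calls the diff makes.

// Hypothetical helper (sketch only): retype exactly those scf.for results
// whose corresponding yielded operand was already rewritten to the coalesced
// layout, leaving all other results (accumulators, untouched pointers) as-is.
static void updateLoopResultTypes(scf::ForOp forOp, scf::YieldOp yieldOp,
                                  Attribute layout, RewriterBase &rewriter) {
  rewriter.modifyOpInPlace(forOp, [&]() {
    for (auto [opType, res] :
         llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
      // A matching type means coalescing did not touch this result.
      if (opType == res.getType())
        continue;
      // The blocked-pointer and layout asserts from the diff are elided here.
      auto resType = cast<tt::PointerType>(res.getType());
      auto tensorType = cast<RankedTensorType>(resType.getPointeeType());
      res.setType(tt::PointerType::get(getNewType(tensorType, layout),
                                       resType.getAddressSpace()));
    }
  });
}

Each call site would then be updateLoopResultTypes(forOp, yieldOp, layout, rewriter) followed by propagateLayout(forOp, layout, rewriter). The substance of the fix is the switch from changeAndPropagateLayout(forOp, ...), which retyped every tensor-pointer result of the loop, to this selective retyping plus a plain propagateLayout, so that a loop carrying two blocked pointers (the new test_block_ptrs case) only has the coalesced one changed.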
