Commit 7a38bd8

Fix tritonintelgpu-pipeline pass on block ptr example. (#3815)
Fixes issue #3810.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent da26bee commit 7a38bd8

File tree

2 files changed: +112, -3 lines


test/TritonIntelGPU/loop-pipeline.mlir

Lines changed: 106 additions & 0 deletions
@@ -262,3 +262,109 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     tt.return
   }
 }
+
+// -----
+
+// COM: Reproducer for issue #3810.
+
+// CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [4, 1, 4], threadsPerWarp = [1, 2, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [4, 4, 1], threadsPerWarp = [1, 16, 2], warpsPerCTA = [4, 1, 1], order = [1, 2, 0]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 8, 2], threadsPerWarp = [4, 8, 1], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
+module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
+    // CHECK: tt.func public @matmul_kernel_descriptor_persistent([[PARAM_0:%.*]]: !tt.ptr<f16> {{.*}}, [[PARAM_1:%.*]]: !tt.ptr<f16> {{.*}}, [[PARAM_2:%.*]]: !tt.ptr<f16> {{.*}}, [[PARAM_3:%.*]]: i32 {{.*}}, [[PARAM_4:%.*]]: i32 {{.*}}, [[PARAM_5:%.*]]: i32 {{.*}})
+    %c448_i32 = arith.constant 448 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c127_i32 = arith.constant 127 : i32
+    %c63_i32 = arith.constant 63 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.addi %arg3, %c127_i32 : i32
+    %2 = arith.divsi %1, %c128_i32 : i32
+    %3 = arith.addi %arg4, %c127_i32 : i32
+    %4 = arith.divsi %3, %c128_i32 : i32
+    %5 = arith.addi %arg5, %c63_i32 : i32
+    %6 = arith.divsi %5, %c64_i32 : i32
+    %7 = arith.muli %2, %4 : i32
+    %8 = arith.extsi %arg5 : i32 to i64
+    %9 = arith.extsi %arg4 : i32 to i64
+    %10 = arith.subi %0, %c448_i32 : i32
+    %11 = arith.muli %4, %c8_i32 : i32
+    %12 = arith.extsi %arg3 : i32 to i64
+    // CHECK: scf.for %[[OUTER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (i32)
+    // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR1]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR2]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR3]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: [[PTR4:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR4]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: scf.for %[[INNER_IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x128xf32, #blocked>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #blocked1>>, !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>)
+    // CHECK: [[PTR5:%.*]] = tt.make_tensor_ptr [[PARAM_0]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR5]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK: [[PTR6:%.*]] = tt.make_tensor_ptr [[PARAM_1]], {{.*}} : <tensor<128x64xf16, #[[BLOCKED1]]>>
+    // CHECK-NEXT: triton_intel_gpu.prefetch [[PTR6]] {boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #[[BLOCKED1]]>>
+    %13 = scf.for %arg6 = %0 to %7 step %c448_i32 iter_args(%arg7 = %10) -> (i32) : i32 {
+      %14 = arith.divsi %arg6, %11 : i32
+      %15 = arith.muli %14, %c8_i32 : i32
+      %16 = arith.subi %2, %15 : i32
+      %17 = arith.minsi %16, %c8_i32 : i32
+      %18 = arith.remsi %arg6, %17 : i32
+      %19 = arith.addi %15, %18 : i32
+      %20 = arith.remsi %arg6, %11 : i32
+      %21 = arith.divsi %20, %17 : i32
+      %22 = arith.muli %19, %c128_i32 : i32
+      %23 = arith.muli %21, %c128_i32 : i32
+      %24 = scf.for %arg8 = %c0_i32 to %6 step %c1_i32 iter_args(%arg9 = %cst) -> (tensor<128x128xf32, #blocked>) : i32 {
+        %44 = arith.muli %arg8, %c64_i32 : i32
+        %45 = tt.make_tensor_ptr %arg0, [%12, %8], [%8, %c1_i64], [%22, %44] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+        %46 = tt.load %45 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+        %47 = tt.make_tensor_ptr %arg1, [%9, %8], [%8, %c1_i64], [%23, %44] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+        %48 = tt.load %47 {triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #blocked1>>
+        %49 = tt.trans %48 {order = array<i32: 1, 0>} : tensor<128x64xf16, #blocked1> -> tensor<64x128xf16, #blocked2>
+        %50 = tt.fp_to_fp %46 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf32, #blocked1>
+        %51 = ttg.convert_layout %50 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+        %52 = tt.fp_to_fp %49 : tensor<64x128xf16, #blocked2> -> tensor<64x128xf32, #blocked2>
+        %53 = ttg.convert_layout %52 : tensor<64x128xf32, #blocked2> -> tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+        %54 = tt.dot %51, %53, %arg9, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
+        scf.yield %54 : tensor<128x128xf32, #blocked>
+      }
+      %25 = arith.addi %arg7, %c448_i32 : i32
+      %26 = arith.divsi %25, %11 : i32
+      %27 = arith.muli %26, %c8_i32 : i32
+      %28 = arith.subi %2, %27 : i32
+      %29 = arith.minsi %28, %c8_i32 : i32
+      %30 = arith.remsi %25, %29 : i32
+      %31 = arith.addi %27, %30 : i32
+      %32 = arith.remsi %25, %11 : i32
+      %33 = arith.divsi %32, %29 : i32
+      %34 = arith.muli %31, %c128_i32 : i32
+      %35 = arith.muli %33, %c128_i32 : i32
+      %36 = tt.reshape %24 : tensor<128x128xf32, #blocked> -> tensor<128x2x64xf32, #blocked3>
+      %37 = tt.trans %36 {order = array<i32: 0, 2, 1>} : tensor<128x2x64xf32, #blocked3> -> tensor<128x64x2xf32, #blocked4>
+      %38 = ttg.convert_layout %37 : tensor<128x64x2xf32, #blocked4> -> tensor<128x64x2xf32, #blocked5>
+      %outLHS, %outRHS = tt.split %38 : tensor<128x64x2xf32, #blocked5> -> tensor<128x64xf32, #blocked1>
+      %39 = arith.truncf %outLHS : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
+      %40 = tt.make_tensor_ptr %arg2, [%12, %9], [%9, %c1_i64], [%34, %35] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+      tt.store %40, %39 : !tt.ptr<tensor<128x64xf16, #blocked1>>
+      %41 = arith.truncf %outRHS : tensor<128x64xf32, #blocked1> to tensor<128x64xf16, #blocked1>
+      %42 = arith.addi %35, %c64_i32 : i32
+      %43 = tt.make_tensor_ptr %arg2, [%12, %9], [%9, %c1_i64], [%34, %42] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #blocked1>>
+      tt.store %43, %41 : !tt.ptr<tensor<128x64xf16, #blocked1>>
+      scf.yield %25 : i32
+    } {tt.flatten}
+    tt.return
+  }
+
+
+}

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 6 additions & 3 deletions
@@ -180,9 +180,12 @@ static Operation *getDefOp(Value v, Operation *op, bool includeArg) {
     if (!seen.insert(v).second)
       break;
     if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) {
-      auto yieldOp = op->getBlock()->getTerminator();
-      v = yieldOp->getOperand(arg.getArgNumber() - 1);
-      continue;
+      Operation *termOp = op->getBlock()->getTerminator();
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(termOp)) {
+        v = yieldOp->getOperand(arg.getArgNumber() - 1);
+        continue;
+      }
+      break;
     }
     break;
   }
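
For context, the hunk above replaces an unconditional operand lookup on the block terminator with a guard: the loop-carried value is traced through the terminator only when that terminator is actually an scf.yield, and the walk stops otherwise. Below is a minimal standalone sketch of the same guard pattern; the helper name lookThroughIterArg and the exact set of headers are illustrative assumptions, not part of this patch.

// Illustrative sketch only: trace a loop-carried block argument back to the
// value yielded for it, guarding on the terminator kind as the patch does.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper (not in the patch).
static Value lookThroughIterArg(BlockArgument arg) {
  Operation *termOp = arg.getOwner()->getTerminator();
  if (auto yieldOp = dyn_cast<scf::YieldOp>(termOp)) {
    // In an scf.for body, block argument 0 is the induction variable, so the
    // value yielded for iter_arg N sits at operand index N - 1.
    if (arg.getArgNumber() > 0)
      return yieldOp->getOperand(arg.getArgNumber() - 1);
  }
  // Terminator is not an scf.yield (or arg is the induction variable): give up.
  return arg;
}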
