Commit b7c440b

Authored by Si Yudong
Handle op with multi results case in changeAndPropagateLayout (#4146)
This PR fixes issue #4000 by updating the changeAndPropagateLayout function to support operations with multiple results.

- Allow operations to have one or more results instead of exactly one.
- Iterate over all results and update the type of those with tensor pointer types.
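For reference, a condensed sketch of the helper after this change, reconstructed from the Coalesce.cpp hunk further down; getNewType, LDBG, and the tt/ttg types are pass-internal names taken from that hunk and assumed to be in scope:

  // Update the type of every block-pointer result of `op` to use the new layout.
  void changeAndPropagateLayout(Operation *op, Attribute layout,
                                IRRewriter &rewriter) const {
    assert(op && op->getNumResults() != 0 &&
           "Expecting operation yielding results");

    rewriter.modifyOpInPlace(op, [&]() {
      for (Value res : op->getResults()) {
        // Results that are not tensor (block) pointers keep their type.
        if (!tt::isTensorPointerType(res.getType()))
          continue;

        auto ptrType = cast<tt::PointerType>(res.getType());
        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
                                         ptrType.getAddressSpace()));
      }
    });
    LDBG("Coalesced op: " << *op);
  }

The new test in coalesce.mlir exercises this with scf.for ops that yield both a tensor accumulator and a block pointer, so only the pointer result's type is rewritten.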
1 parent 6f851e2 commit b7c440b

File tree: 2 files changed (+63, -10 lines)


test/TritonIntelGPU/coalesce.mlir

Lines changed: 52 additions & 0 deletions
@@ -470,3 +470,55 @@ module attributes {triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 :
     tt.return
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 2, 1], order = [0, 1, 2]}>
+module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
+  // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
+  // CHECK-DAG: [[BLOCKED_LAYOUT3:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 2, 1], order = [0, 1, 2]}>
+  // CHECK: @triton_red_fused_prod_0
+  tt.func public @triton_red_fused_prod_0(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) attributes {noinline = false} {
+    %c4_i32 = arith.constant 4 : i32
+    %cst = arith.constant dense<1.000000e+00> : tensor<1x4x4xf32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %0 = arith.extsi %arg2 : i32 to i64
+    %1 = arith.extsi %arg3 : i32 to i64
+    // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>
+    %2 = tt.make_tensor_ptr %arg0, [%0, %0], [%1, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<4x4xf32, #blocked1>>
+    %3 = tt.splat %arg5 : i32 -> tensor<1x1x4xi32, #blocked2>
+    // CHECK: [[RES1:%.*]]:2 = scf.for {{.*}} iter_args([[ARG7:%.*]] = %cst, [[ARG8:%.*]] = [[PTR]]) -> (tensor<1x4x4xf32, [[BLOCKED_LAYOUT]]>, !tt.ptr<tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>)
+    %4:2 = scf.for %arg6 = %c0_i32 to %arg4 step %c4_i32 iter_args(%arg7 = %cst, %arg8 = %2) -> (tensor<1x4x4xf32, #blocked>, !tt.ptr<tensor<4x4xf32, #blocked1>>) : i32 {
+      // CHECK: [[RES2:%.*]]:2 = scf.for {{.*}} iter_args([[ARG10:%.*]] = [[ARG7:%.*]], [[ARG11:%.*]] = [[ARG8:%.*]]) -> (tensor<1x4x4xf32, [[BLOCKED_LAYOUT]]>, !tt.ptr<tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>)
+      %5:2 = scf.for %arg9 = %c0_i32 to %arg5 step %c4_i32 iter_args(%arg10 = %arg7, %arg11 = %arg8) -> (tensor<1x4x4xf32, #blocked>, !tt.ptr<tensor<4x4xf32, #blocked1>>) : i32 {
+        %7 = tt.splat %arg9 : i32 -> tensor<1x1x4xi32, #blocked2>
+        %8 = arith.cmpi slt, %7, %3 : tensor<1x1x4xi32, #blocked2>
+        // CHECK-DAG: [[LOAD:%.*]] = tt.load [[ARG11]] {{.*}} : !tt.ptr<tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>
+        // CHECK: [[CONVERT_LAYOUT_0:%.*]] = ttg.convert_layout [[LOAD]] : tensor<4x4xf32, [[BLOCKED_LAYOUT1]]> -> tensor<4x4xf32, #ttg.slice<{dim = 0, parent = [[BLOCKED_LAYOUT3]]}>>
+        // CHECK: [[CONVERT_LAYOUT_1:%.*]] = ttg.convert_layout {{.*}} : tensor<1x4x4xf32, [[BLOCKED_LAYOUT3]]> -> tensor<1x4x4xf32, [[BLOCKED_LAYOUT]]>
+        // CHECK: [[CONVERT_LAYOUT_2:%.*]] = ttg.convert_layout {{.*}} : tensor<1x4x4xi1, [[BLOCKED_LAYOUT2]]> -> tensor<1x4x4xi1, [[BLOCKED_LAYOUT]]>
+        %9 = tt.load %arg11 evictionPolicy = evict_first {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32} : !tt.ptr<tensor<4x4xf32, #blocked1>>
+        %10 = ttg.convert_layout %9 : tensor<4x4xf32, #blocked1> -> tensor<4x4xf32, #ttg.slice<{dim = 0, parent = #blocked3}>>
+        %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<4x4xf32, #ttg.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x4x4xf32, #blocked3>
+        %12 = ttg.convert_layout %11 : tensor<1x4x4xf32, #blocked3> -> tensor<1x4x4xf32, #blocked>
+        %13 = tt.broadcast %8 : tensor<1x1x4xi1, #blocked2> -> tensor<1x4x4xi1, #blocked2>
+        %14 = ttg.convert_layout %13 : tensor<1x4x4xi1, #blocked2> -> tensor<1x4x4xi1, #blocked>
+        %15 = arith.select %14, %12, %arg10 : tensor<1x4x4xi1, #blocked>, tensor<1x4x4xf32, #blocked>
+        // CHECK-DAG: [[ADVANCE1:%.*]] = tt.advance [[ARG11]], {{.*}} : <tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>
+        // CHECK: scf.yield {{.*}} : tensor<1x4x4xf32, [[BLOCKED_LAYOUT]]>, !tt.ptr<tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>
+        %16 = tt.advance %arg11, [%c0_i32, %c4_i32] : <tensor<4x4xf32, #blocked1>>
+        scf.yield %15, %16 : tensor<1x4x4xf32, #blocked>, !tt.ptr<tensor<4x4xf32, #blocked1>>
+      }
+      // CHECK-DAG: [[ADVANCE2:%.*]] = tt.advance [[RES2]]#1, {{.*}} : <tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>
+      // CHECK: scf.yield [[RES2]]#0, [[ADVANCE2]] : tensor<1x4x4xf32, [[BLOCKED_LAYOUT]]>, !tt.ptr<tensor<4x4xf32, [[BLOCKED_LAYOUT1]]>>
+      %6 = tt.advance %5#1, [%c4_i32, %c4_i32] : <tensor<4x4xf32, #blocked1>>
+      scf.yield %5#0, %6 : tensor<1x4x4xf32, #blocked>, !tt.ptr<tensor<4x4xf32, #blocked1>>
+    }
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 11 additions & 10 deletions
@@ -174,18 +174,19 @@ struct CoalescePass
   // to its users.
   void changeAndPropagateLayout(Operation *op, Attribute layout,
                                 IRRewriter &rewriter) const {
-    assert(op && op->getNumResults() == 1 &&
-           "Expecting operation yielding a result");
+    assert(op && op->getNumResults() != 0 &&
+           "Expecting operation yielding results");
 
     rewriter.modifyOpInPlace(op, [&]() {
-      Value res = op->getOpResult(0);
-      assert(tt::isTensorPointerType(res.getType()) &&
-             "Expecting a block pointer");
-
-      auto ptrType = cast<tt::PointerType>(res.getType());
-      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-      res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                       ptrType.getAddressSpace()));
+      for (Value res : op->getResults()) {
+        if (!tt::isTensorPointerType(res.getType()))
+          continue;
+
+        auto ptrType = cast<tt::PointerType>(res.getType());
+        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                         ptrType.getAddressSpace()));
+      }
     });
     LDBG("Coalesced op: " << *op);
 