Commit 8aabcc7

Adapt to reshape changes
1 parent 94ba49c commit 8aabcc7

File tree

2 files changed: 16 additions & 18 deletions


test/TritonIntelGPU/optimize-reduction.mlir

Lines changed: 6 additions & 6 deletions
@@ -318,7 +318,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
 // CHECK: tt.func @test(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$DPAS]]>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<128x64xf32, #[[$DPAS]]> -> tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]>
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$DPAS]]> -> tensor<16x2x4x16x2x2x1xf32, #[[$BLOCKED_EW]]>
 // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32
@@ -330,14 +330,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
 // CHECK: }) : (tensor<16x2x4x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>) -> tensor<16x2x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>>
 // CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x2x4x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> -> tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]>
+// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x2x4x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]>
 // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
 // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
 // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
 // CHECK: tt.reduce.return %[[VAL_14]] : f32
 // CHECK: }) : (tensor<16x2x4x32xf32, #[[$BLOCKED_RED]]>) -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>>
 // CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> -> tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
+// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x2x4xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
 // CHECK: tt.return %[[VAL_16]] : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
 // CHECK: }
 tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
@@ -351,7 +351,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 
 // CHECK: tt.func @test_repeat_layout(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<256x64xf32, #[[$DPAS]]>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>> {
-// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] {allow_reorder = true, efficient_layout} : tensor<256x64xf32, #[[$DPAS]]> -> tensor<16x2x8x16x2x2x1xf32, #[[$BLOCKED_EW]]>
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<256x64xf32, #[[$DPAS]]> -> tensor<16x2x8x16x2x2x1xf32, #[[$BLOCKED_EW]]>
 // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 6 : i32}> ({
 // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32):
 // CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32
@@ -363,14 +363,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 // CHECK: tt.reduce.return %[[VAL_9]] : f32
 // CHECK: }) : (tensor<16x2x8x16x2x2xf32, #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>>) -> tensor<16x2x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>>
 // CHECK: %[[CONV:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16x2x8x16x2xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #[[$BLOCKED_EW]]}>}>> -> tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]>
-// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] {allow_reorder = true, efficient_layout} : tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]>
+// CHECK: %[[VAL_10:.*]] = tt.reshape %[[CONV]] allow_reorder efficient_layout : tensor<16x2x8x16x2xf32, #[[$BLOCKED_TRANS]]> -> tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]>
 // CHECK: %[[VAL_11:.*]] = "tt.reduce"(%[[VAL_10]]) <{axis = 3 : i32}> ({
 // CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
 // CHECK: %[[VAL_14:.*]] = arith.maxnumf %[[VAL_12]], %[[VAL_13]] : f32
 // CHECK: tt.reduce.return %[[VAL_14]] : f32
 // CHECK: }) : (tensor<16x2x8x32xf32, #[[$BLOCKED_RED]]>) -> tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>>
 // CHECK: %[[VAL_15:.*]] = triton_gpu.convert_layout %[[VAL_11]] : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_RED]]}>> -> tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>>
-// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] {allow_reorder = true, efficient_layout} : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
+// CHECK: %[[VAL_16:.*]] = tt.reshape %[[VAL_15]] allow_reorder efficient_layout : tensor<16x2x8xf32, #triton_gpu.slice<{dim = 3, parent = #[[$BLOCKED_FINAL]]}>> -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
 // CHECK: tt.return %[[VAL_16]] : tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #[[$DPAS]]}>>
 // CHECK: }
 tt.func @test_repeat_layout(%arg0: tensor<256x64xf32, #mma>) -> tensor<256xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
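
In both test functions, the only change to the CHECK lines is the `tt.reshape` assembly syntax: `allow_reorder` and `efficient_layout` now print as bare unit attributes rather than an attribute dictionary. Shapes and layouts are untouched. Side by side, from the first hunk (`#dpas` and `#blocked_ew` stand in for the FileCheck-captured layout attributes):

// old assembly syntax
%1 = tt.reshape %0 {allow_reorder = true, efficient_layout} : tensor<128x64xf32, #dpas> -> tensor<16x2x4x16x2x2x1xf32, #blocked_ew>
// new assembly syntax
%1 = tt.reshape %0 allow_reorder efficient_layout : tensor<128x64xf32, #dpas> -> tensor<16x2x4x16x2x2x1xf32, #blocked_ew>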

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 10 additions & 12 deletions
@@ -31,14 +31,6 @@ static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
   return rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
 }
 
-static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
-                                       Type type, Value val) {
-  auto reshapeOp =
-      rewriter.create<ReshapeOp>(loc, type, val, /*allow_reorder=*/true);
-  reshapeOp.setEfficientLayout(true);
-  return reshapeOp;
-}
-
 // clang-format off
 /// Optimize reduction with DPAS-encoded input.
 ///
@@ -296,7 +288,9 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return createReshapeForReduction(rewriter, op.getLoc(), type, val);
+    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+                                      /*allow_reorder=*/true,
+                                      /*efficient_layout=*/true);
   }
 
   Value performReduction(ReduceOp op, PatternRewriter &rewriter, Value val,
@@ -380,7 +374,9 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return createReshapeForReduction(rewriter, op.getLoc(), type, val);
+    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+                                      /*allow_reorder=*/true,
+                                      /*efficient_layout=*/true);
   }
 
   Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,
@@ -422,8 +418,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
   Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter,
                               Value val) const {
-    return createReshapeForReduction(rewriter, op.getLoc(),
-                                     op.getResult().front().getType(), val);
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      op.getResult().front().getType(), val,
+                                      /*allow_reorder=*/true,
+                                      /*efficient_layout=*/true);
   }
 };
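
All three call sites follow the same mechanical rewrite: the deleted `createReshapeForReduction` helper had to create the op and then flip the flag afterwards via `setEfficientLayout(true)`, whereas the updated `ReshapeOp` builder accepts `efficient_layout` directly, so the helper no longer pays for itself. A minimal before/after sketch, assuming `mlir::PatternRewriter` and Triton's `ReshapeOp` are in scope (function names and include paths below are illustrative assumptions, not taken from this commit):

#include "mlir/IR/PatternMatch.h"             // mlir::PatternRewriter (assumed include)
#include "triton/Dialect/Triton/IR/Dialect.h" // triton::ReshapeOp (assumed include)

using namespace mlir;
using namespace mlir::triton;

// Old style: one flag goes through the builder, the other is set on the
// freshly created op.
static Value reshapeOldStyle(PatternRewriter &rewriter, Location loc,
                             Type type, Value val) {
  auto reshapeOp =
      rewriter.create<ReshapeOp>(loc, type, val, /*allow_reorder=*/true);
  reshapeOp.setEfficientLayout(true);
  return reshapeOp;
}

// New style: both flags go through the builder in a single call.
static Value reshapeNewStyle(PatternRewriter &rewriter, Location loc,
                             Type type, Value val) {
  return rewriter.create<ReshapeOp>(loc, type, val,
                                    /*allow_reorder=*/true,
                                    /*efficient_layout=*/true);
}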
