
Commit 99abe0f

[Triton] Use UnitAttr in tt.reshape definition
Use `UnitAttr` in the `tt.reshape` definition to make these ops easier to create and work with. `allow_reorder` and `efficient_layout` are now `UnitAttr`s, which allows dropping the additional custom builder.

Signed-off-by: victor-eds <[email protected]>
1 parent 6018c7b commit 99abe0f
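Making both flags `UnitAttr` means their presence alone encodes the boolean, so the ODS-generated accessors and builders work with plain `bool`s. A minimal sketch of creation and querying after this change, assuming the usual ODS-generated signatures (`buildReshape` is a hypothetical helper; `builder`, `loc`, `resultType`, and `src` stand in for real values):

    #include <cassert>

    #include "mlir/IR/Builders.h"
    #include "triton/Dialect/Triton/IR/Dialect.h"

    // Hypothetical helper: create a reorderable reshape with no layout hint.
    static mlir::Value buildReshape(mlir::OpBuilder &builder, mlir::Location loc,
                                    mlir::Type resultType, mlir::Value src) {
      // With UnitAttr arguments the generated builder takes plain bools, so the
      // custom builder that passed /*efficient_layout=*/UnitAttr() is no longer
      // needed.
      auto op = builder.create<mlir::triton::ReshapeOp>(
          loc, resultType, src,
          /*allow_reorder=*/true, /*efficient_layout=*/false);
      // Accessors now return bool directly; no std::optional, no .has_value().
      assert(op.getAllowReorder() && !op.getEfficientLayout());
      return op.getResult();
    }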

15 files changed: +37 −44 lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 7 deletions
@@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
     If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
     The compiler is still free to change it for better performance.
   }];
-  let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr<UnitAttr>:$efficient_layout);
+  let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
   let results = (outs TT_Tensor:$result);
-  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
+  let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let builders = [
-      OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder),
-      [{
-        build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr());
-      }]>];
 }

 def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
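The new assemblyFormat prints each flag as an optional bare keyword instead of an entry in the attribute dictionary. Illustrative IR (hypothetical values, not taken from this commit's tests; `%src` is assumed to be a tensor<8xf32>):

    // Before: flags lived in the attribute dictionary, even when false.
    %0 = tt.reshape %src {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32>

    // After: a set UnitAttr prints as a keyword; an unset one is simply omitted.
    %1 = tt.reshape %src allow_reorder : tensor<8xf32> -> tensor<2x4xf32>
    %2 = tt.reshape %src allow_reorder efficient_layout : tensor<8xf32> -> tensor<2x4xf32>
    %3 = tt.reshape %src : tensor<8xf32> -> tensor<2x4xf32>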

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
 }

 LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
-  if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
+  if (!op.getAllowReorder() || op.getEfficientLayout())
     return failure();
   return canonicalizeViewOrBroadcast(op, rewriter);
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 3 deletions
@@ -2764,7 +2764,7 @@ struct CanonicalizeConvertFromReshape
       return failure();
     if (isExpensiveView(convert.getSrc().getType(), op.getType()))
       return failure();
-    if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
+    if (!op.getAllowReorder() || op.getEfficientLayout())
       return failure();

     rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
@@ -2885,8 +2885,7 @@ struct CanonicalizeConvertFromConvert

     // cvt(reshape) -> reshape
     if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
-      if (!reshape.getAllowReorder() ||
-          reshape.getEfficientLayout().has_value() ||
+      if (!reshape.getAllowReorder() || reshape.getEfficientLayout() ||
           isExpensiveView(reshape.getSrc().getType(), op.getType()))
         return failure();

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 2 deletions
@@ -556,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
     RankedTensorType newDstType =
         RankedTensorType::get(reshapeDstType.getShape(),
                               reshapeDstType.getElementType(), targetEncoding);
-    return reshape.getAllowReorder() &&
-           !reshape.getEfficientLayout().has_value() &&
+    return reshape.getAllowReorder() && !reshape.getEfficientLayout() &&
            !triton::gpu::isExpensiveView(reshape.getSrc().getType(),
                                          newDstType);
   }

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 1 addition & 1 deletion
@@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK-NEXT: [[STRUCT2:%.*]] = llvm.insertvalue [[ARG0_1]], [[STRUCT1]][1]
   // CHECK-NEXT: [[T0:%.*]] = llvm.extractvalue [[STRUCT2]][0]
   // CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue [[STRUCT2]][1]
-  %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
+  %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
   // CHECK: [[RES:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   // CHECK-NEXT: [[RES1:%.*]] = llvm.insertvalue [[T0]], [[RES]][0]
   // CHECK-NEXT: [[RES2:%.*]] = llvm.insertvalue [[T1]], [[RES1]][1]

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK: llvm.mlir.undef
   // CHECK: %[[T0:.*]] = llvm.extractvalue
   // CHECK: %[[T1:.*]] = llvm.extractvalue
-  %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
+  %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
   // CHECK: llvm.mlir.undef
   // CHECK: llvm.insertvalue %[[T0]]
   // CHECK: llvm.insertvalue %[[T1]]

test/Triton/combine.mlir

Lines changed: 6 additions & 6 deletions
@@ -292,15 +292,15 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>)

 // CHECK-LABEL: @test_canonicalize_view
 tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
-  %view0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32>
-  // CHECK: %{{.*}} = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<4x2xf32>
-  %view1 = tt.reshape %view0 {allow_reorder = true} : tensor<2x4xf32> -> tensor<4x2xf32>
+  %view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32>
+  // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32>
+  %view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32>

   %splat = tt.splat %arg1 : tensor<f32> -> tensor<8xf32>
   // CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<2x2x2xf32>
-  %view2 = tt.reshape %splat {allow_reorder = true} : tensor<8xf32> -> tensor<2x2x2xf32>
+  %view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>

-  %view3 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<8xf32>
+  %view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32>
   // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
   %add = arith.addf %view3, %arg0 : tensor<8xf32>

@@ -329,7 +329,7 @@ tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x
   %a = arith.constant dense<1.0> : tensor<1x128xf32>

   // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32>
-  %b = tt.reshape %a {allow_reorder = true} : tensor<1x128xf32> -> tensor<16x8xf32>
+  %b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32>

   // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32>
   %c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32>

test/Triton/invalid.mlir

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) {

 tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) {
   // expected-error @+1 {{number of src and dst elements of reshape must be the same}}
-  %a = tt.reshape %arg0 {allow_reorder = false} : tensor<32x128xf16> -> tensor<64x32xf16>
+  %a = tt.reshape %arg0 allow_reorder : tensor<32x128xf16> -> tensor<64x32xf16>
   tt.return
 }

test/Triton/ops.mlir

Lines changed: 2 additions & 2 deletions
@@ -225,8 +225,8 @@ tt.func @inline_asm_scalar(%0: i32) {

 // CHECK-LABEL: reshape
 tt.func @reshape(%0: tensor<512xi32>) {
-  // CHECK: tt.reshape %{{.+}} {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32>
-  %1 = tt.reshape %0 {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32>
+  // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
+  %1 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
   tt.return
 }

test/TritonGPU/canonicalize.mlir

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 // CHECK-LABEL: @test_canonicalize_convert_view
 // CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
 // CHECK-NOT: triton_gpu.convert_layout
-// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] {allow_reorder = true}
+// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder
 // CHECK: tt.return %[[V]]
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
@@ -13,7 +13,7 @@
 module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} {
 tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
   %c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2>
-  %r = tt.reshape %c {allow_reorder = true} : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1>
+  %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1>
   tt.return %r : tensor<4096xf32, #blocked1>
 }
 } // end module
@@ -25,15 +25,15 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) ->
 // CHECK-LABEL: @test_canonicalize_convert_expensive_view
 // CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32
 // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]]
-// CHECK: %[[V:.+]] = tt.reshape %[[C]] {allow_reorder = true}
+// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder
 // CHECK: tt.return %[[V]]
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
 module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} {
 tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
   %c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2>
-  %r = tt.reshape %c {allow_reorder = true} : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1>
+  %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1>
   tt.return %r : tensor<4096xf32, #blocked1>
 }
 } // end module
