9 changes: 2 additions & 7 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
The compiler is still free to change it for better performance.
}];
let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr<UnitAttr>:$efficient_layout);
let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
let hasVerifier = 1;
let builders = [
OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder),
[{
build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr());
}]>];
}

def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
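Since both flags are now plain `UnitAttr`s, the hand-written builder above is dropped. A minimal sketch of what a creation site might look like afterwards — it assumes the ODS-generated builder takes `bool`s for the unit attributes, which is not shown in this diff:

```cpp
// Sketch under assumptions: with UnitAttr arguments, the generated builder and
// accessors are assumed to use `bool` (true iff the attribute is present).
#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

static mlir::Value reshapeAllowingReorder(mlir::OpBuilder &b, mlir::Location loc,
                                          mlir::Type dstType, mlir::Value src) {
  // Prints as `tt.reshape %src allow_reorder : ... -> ...` under the new
  // declarative assembly format; `efficient_layout` is simply omitted when unset.
  return b.create<mlir::triton::ReshapeOp>(loc, dstType, src,
                                           /*allow_reorder=*/true,
                                           /*efficient_layout=*/false);
}
```
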
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/IR/Ops.cpp
@@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
}

LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();
return canonicalizeViewOrBroadcast(op, rewriter);
}
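The same condition recurs in the TritonGPU patterns below: with `UnitAttr`, both accessors return `bool`, so the old `.has_value()` test on the `OptionalAttr<UnitAttr>` getter disappears. A hypothetical helper capturing the shared check (a sketch only; this function is not part of the PR):

```cpp
#include "triton/Dialect/Triton/IR/Dialect.h"

// With UnitAttr, both flags read as plain bools; the reshape may be
// canonicalized/folded only when reordering is allowed and no layout hint is set.
static bool reshapeIsFreelyReorderable(mlir::triton::ReshapeOp reshape) {
  return reshape.getAllowReorder() && !reshape.getEfficientLayout();
}
```
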
5 changes: 2 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2764,7 +2764,7 @@ struct CanonicalizeConvertFromReshape
return failure();
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();

rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
@@ -2885,8 +2885,7 @@ struct CanonicalizeConvertFromConvert

// cvt(reshape) -> reshape
if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
if (!reshape.getAllowReorder() ||
reshape.getEfficientLayout().has_value() ||
if (!reshape.getAllowReorder() || reshape.getEfficientLayout() ||
isExpensiveView(reshape.getSrc().getType(), op.getType()))
return failure();

3 changes: 1 addition & 2 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -556,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
RankedTensorType newDstType =
RankedTensorType::get(reshapeDstType.getShape(),
reshapeDstType.getElementType(), targetEncoding);
return reshape.getAllowReorder() &&
!reshape.getEfficientLayout().has_value() &&
return reshape.getAllowReorder() && !reshape.getEfficientLayout() &&
!triton::gpu::isExpensiveView(reshape.getSrc().getType(),
newDstType);
}
2 changes: 1 addition & 1 deletion test/Conversion/intel/tritongpu_to_gen.mlir
@@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: [[STRUCT2:%.*]] = llvm.insertvalue [[ARG0_1]], [[STRUCT1]][1]
// CHECK-NEXT: [[T0:%.*]] = llvm.extractvalue [[STRUCT2]][0]
// CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue [[STRUCT2]][1]
%0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
%0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
// CHECK: [[RES:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK-NEXT: [[RES1:%.*]] = llvm.insertvalue [[T0]], [[RES]][0]
// CHECK-NEXT: [[RES2:%.*]] = llvm.insertvalue [[T1]], [[RES1]][1]
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
@@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
%0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
%0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
12 changes: 6 additions & 6 deletions test/Triton/combine.mlir
@@ -292,15 +292,15 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>)

// CHECK-LABEL: @test_canonicalize_view
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
%view0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<4x2xf32>
%view1 = tt.reshape %view0 {allow_reorder = true} : tensor<2x4xf32> -> tensor<4x2xf32>
%view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32>
%view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32>

%splat = tt.splat %arg1 : tensor<f32> -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<2x2x2xf32>
%view2 = tt.reshape %splat {allow_reorder = true} : tensor<8xf32> -> tensor<2x2x2xf32>
%view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>

%view3 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<8xf32>
%view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32>
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
%add = arith.addf %view3, %arg0 : tensor<8xf32>

@@ -329,7 +329,7 @@ tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x
%a = arith.constant dense<1.0> : tensor<1x128xf32>

// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32>
%b = tt.reshape %a {allow_reorder = true} : tensor<1x128xf32> -> tensor<16x8xf32>
%b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32>

// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32>
%c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32>
2 changes: 1 addition & 1 deletion test/Triton/invalid.mlir
@@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) {

tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) {
// expected-error @+1 {{number of src and dst elements of reshape must be the same}}
%a = tt.reshape %arg0 {allow_reorder = false} : tensor<32x128xf16> -> tensor<64x32xf16>
%a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16>
tt.return
}

25 changes: 23 additions & 2 deletions test/Triton/ops.mlir
@@ -225,8 +225,29 @@ tt.func @inline_asm_scalar(%0: i32) {

// CHECK-LABEL: reshape
tt.func @reshape(%0: tensor<512xi32>) {
// CHECK: tt.reshape %{{.+}} {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32>
// CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32>
tt.return
}

// CHECK-LABEL: reshape_allow_reorder
tt.func @reshape_allow_reorder(%0: tensor<512xi32>) {
// CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
tt.return
}

// CHECK-LABEL: reshape_efficient_layout
tt.func @reshape_efficient_layout(%0: tensor<512xi32>) {
// CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
tt.return
}

// CHECK-LABEL: reshape_allow_reorder_efficient_layout
tt.func @reshape_allow_reorder_efficient_layout(%0: tensor<512xi32>) {
// CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
tt.return
}

8 changes: 4 additions & 4 deletions test/TritonGPU/canonicalize.mlir
@@ -4,7 +4,7 @@
// CHECK-LABEL: @test_canonicalize_convert_view
// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] {allow_reorder = true}
// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder
// CHECK: tt.return %[[V]]
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
@@ -13,7 +13,7 @@
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} {
tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
%c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2>
%r = tt.reshape %c {allow_reorder = true} : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1>
%r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1>
tt.return %r : tensor<4096xf32, #blocked1>
}
} // end module
@@ -25,15 +25,15 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) ->
// CHECK-LABEL: @test_canonicalize_convert_expensive_view
// CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32
// CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]]
// CHECK: %[[V:.+]] = tt.reshape %[[C]] {allow_reorder = true}
// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder
// CHECK: tt.return %[[V]]
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} {
tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
%c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2>
%r = tt.reshape %c {allow_reorder = true} : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1>
%r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1>
tt.return %r : tensor<4096xf32, #blocked1>
}
} // end module
10 changes: 5 additions & 5 deletions test/TritonGPU/combine.mlir
@@ -2097,7 +2097,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> {
// CHECK-NOT: triton_gpu.convert_layout
%a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1>
%b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2>
%b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2>
%c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3>
tt.return %c : tensor<32xf32, #blocked3>
}
@@ -2116,7 +2116,7 @@
// CHECK: tt.reshape
// CHECK: triton_gpu.convert_layout
%a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1>
%b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2>
%b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2>
tt.return %b : tensor<32xf32, #blocked2>
}
}
@@ -2133,7 +2133,7 @@
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: arith.truncf
// CHECK: triton_gpu.convert_layout
%a = tt.reshape %arg0 {allow_reorder = true, efficient_layout} : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1>
%a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1>
%b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2>
%c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2>
tt.return %c : tensor<32xf16, #blocked2>
@@ -2536,9 +2536,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>>
%2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2>
%3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2>
%4 = tt.reshape %3 {allow_reorder = false} : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1>
%4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1>
%5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2>
%6 = tt.reshape %5 {allow_reorder = false} : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1>
%6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1>
%7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1>
%8 = arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1>
%9 = triton_gpu.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked>
2 changes: 1 addition & 1 deletion test/TritonGPU/loop-pipeline-hip.mlir
@@ -221,7 +221,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr<f32>, #blocked2>, tensor<64x8x32x!tt.ptr<f32>, #blocked1>) : i32 {
%39 = tt.load %arg9 : tensor<1x512x!tt.ptr<f32>, #blocked2>
%40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr<f32>, #blocked1>
%41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5>
%41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5>
%43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
%44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>
%45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked>
2 changes: 1 addition & 1 deletion test/TritonGPU/loop-pipeline.mlir
@@ -1341,7 +1341,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
%84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked>
%85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3>
%86 = tt.trans %85 {order = array<i32: 0, 2, 1>} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4>
%87 = tt.reshape %86 {allow_reorder = false} : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5>
%87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5>
%88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma>