From 99abe0fa89ce8413698f65f80887c685d7d5241f Mon Sep 17 00:00:00 2001 From: victor-eds Date: Wed, 16 Oct 2024 10:29:10 +0100 Subject: [PATCH 1/2] [Triton] Use `UnitAttr` in `tt.reshape` definition Use `UnitAttr` in `tt.reshape` definition to make working with them and creating them easier. `allow_reorder` and `efficient_layout` are now `UnitAttr`, allowing dropping the additional constructor. Signed-off-by: victor-eds --- include/triton/Dialect/Triton/IR/TritonOps.td | 9 ++------- lib/Dialect/Triton/IR/Ops.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 ++--- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 3 +-- test/Conversion/intel/tritongpu_to_gen.mlir | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 2 +- test/Triton/combine.mlir | 12 ++++++------ test/Triton/invalid.mlir | 2 +- test/Triton/ops.mlir | 4 ++-- test/TritonGPU/canonicalize.mlir | 8 ++++---- test/TritonGPU/combine.mlir | 10 +++++----- test/TritonGPU/loop-pipeline-hip.mlir | 2 +- test/TritonGPU/loop-pipeline.mlir | 2 +- test/TritonGPU/optimize-locality.mlir | 12 ++++++------ test/TritonIntelGPU/combine.mlir | 6 +++--- 15 files changed, 37 insertions(+), 44 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 66946c20cc..a8358d968c 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure, If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. The compiler is still free to change it for better performance. }]; - let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr:$efficient_layout); + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); let results = (outs TT_Tensor:$result); - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; let hasCanonicalizeMethod = 1; let hasFolder = 1; let hasVerifier = 1; - let builders = [ - OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder), - [{ - build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr()); - }]>]; } def TT_BroadcastOp : TT_Op<"broadcast", [Pure, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 1240caebe2..c2c057f42c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op, } LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { - if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + if (!op.getAllowReorder() || op.getEfficientLayout()) return failure(); return canonicalizeViewOrBroadcast(op, rewriter); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 378a501175..24b24bb701 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2764,7 +2764,7 @@ struct CanonicalizeConvertFromReshape return failure(); if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); - if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + if (!op.getAllowReorder() || op.getEfficientLayout()) return failure(); rewriter.replaceOpWithNewOp( @@ -2885,8 +2885,7 @@ struct CanonicalizeConvertFromConvert // cvt(reshape) -> reshape if (auto reshape = dyn_cast(arg)) { - if (!reshape.getAllowReorder() || - reshape.getEfficientLayout().has_value() || + if (!reshape.getAllowReorder() || reshape.getEfficientLayout() || isExpensiveView(reshape.getSrc().getType(), op.getType())) return failure(); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 91acba38bf..4ef9d1cd1d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -556,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { RankedTensorType newDstType = RankedTensorType::get(reshapeDstType.getShape(), reshapeDstType.getElementType(), targetEncoding); - return reshape.getAllowReorder() && - !reshape.getEfficientLayout().has_value() && + return reshape.getAllowReorder() && !reshape.getEfficientLayout() && !triton::gpu::isExpensiveView(reshape.getSrc().getType(), newDstType); } diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir index 4d77fe657e..28b4a81e4f 100644 --- a/test/Conversion/intel/tritongpu_to_gen.mlir +++ b/test/Conversion/intel/tritongpu_to_gen.mlir @@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: [[STRUCT2:%.*]] = llvm.insertvalue [[ARG0_1]], [[STRUCT1]][1] // CHECK-NEXT: [[T0:%.*]] = llvm.extractvalue [[STRUCT2]][0] // CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue [[STRUCT2]][1] - %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> + %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> // CHECK: [[RES:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK-NEXT: [[RES1:%.*]] = llvm.insertvalue [[T0]], [[RES]][0] // CHECK-NEXT: [[RES2:%.*]] = llvm.insertvalue [[T1]], [[RES1]][1] diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 76e85e3c7a..e2f43f4ba6 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef // CHECK: %[[T0:.*]] = llvm.extractvalue // CHECK: %[[T1:.*]] = llvm.extractvalue - %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> + %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> // CHECK: llvm.mlir.undef // CHECK: llvm.insertvalue %[[T0]] // CHECK: llvm.insertvalue %[[T1]] diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 2197b18738..41a3ba15a8 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -292,15 +292,15 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor, %arg1: tensor<1xf32>) // CHECK-LABEL: @test_canonicalize_view tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) { - %view0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32> - // CHECK: %{{.*}} = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<4x2xf32> - %view1 = tt.reshape %view0 {allow_reorder = true} : tensor<2x4xf32> -> tensor<4x2xf32> + %view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32> + // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32> + %view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32> %splat = tt.splat %arg1 : tensor -> tensor<8xf32> // CHECK: %{{.*}} = tt.splat %arg1 : tensor -> tensor<2x2x2xf32> - %view2 = tt.reshape %splat {allow_reorder = true} : tensor<8xf32> -> tensor<2x2x2xf32> + %view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32> - %view3 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<8xf32> + %view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32> // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32> %add = arith.addf %view3, %arg0 : tensor<8xf32> @@ -329,7 +329,7 @@ tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x %a = arith.constant dense<1.0> : tensor<1x128xf32> // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32> - %b = tt.reshape %a {allow_reorder = true} : tensor<1x128xf32> -> tensor<16x8xf32> + %b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32> // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32> %c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32> diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 6eee82fec7..c35395ae20 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) { tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) { // expected-error @+1 {{number of src and dst elements of reshape must be the same}} - %a = tt.reshape %arg0 {allow_reorder = false} : tensor<32x128xf16> -> tensor<64x32xf16> + %a = tt.reshape %arg0 allow_reorder : tensor<32x128xf16> -> tensor<64x32xf16> tt.return } diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index c5d7ec8b65..94e08dde53 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -225,8 +225,8 @@ tt.func @inline_asm_scalar(%0: i32) { // CHECK-LABEL: reshape tt.func @reshape(%0: tensor<512xi32>) { - // CHECK: tt.reshape %{{.+}} {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32> - %1 = tt.reshape %0 {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32> tt.return } diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ecee359cb1..9422bb0f85 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: @test_canonicalize_convert_view // CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 // CHECK-NOT: triton_gpu.convert_layout -// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] {allow_reorder = true} +// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder // CHECK: tt.return %[[V]] #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -13,7 +13,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { %c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> - %r = tt.reshape %c {allow_reorder = true} : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } } // end module @@ -25,7 +25,7 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> // CHECK-LABEL: @test_canonicalize_convert_expensive_view // CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] -// CHECK: %[[V:.+]] = tt.reshape %[[C]] {allow_reorder = true} +// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder // CHECK: tt.return %[[V]] #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -33,7 +33,7 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { %c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> - %r = tt.reshape %c {allow_reorder = true} : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } } // end module diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 601b0cc44b..78c6f68bf6 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2097,7 +2097,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { // CHECK-NOT: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> tt.return %c : tensor<32xf32, #blocked3> } @@ -2116,7 +2116,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: tt.reshape // CHECK: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> tt.return %b : tensor<32xf32, #blocked2> } } @@ -2133,7 +2133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-NOT: triton_gpu.convert_layout // CHECK: arith.truncf // CHECK: triton_gpu.convert_layout - %a = tt.reshape %arg0 {allow_reorder = true, efficient_layout} : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> + %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> %b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> tt.return %c : tensor<32xf16, #blocked2> @@ -2536,9 +2536,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> %3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2> - %4 = tt.reshape %3 {allow_reorder = false} : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2> - %6 = tt.reshape %5 {allow_reorder = false} : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1> %8 = arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1> %9 = triton_gpu.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 28c815febb..7fa7812c5a 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -221,7 +221,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1>) : i32 { %39 = tt.load %arg9 : tensor<1x512x!tt.ptr, #blocked2> %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr, #blocked1> - %41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> + %41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 1c8deb67ce..fdb1764174 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1341,7 +1341,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> - %87 = tt.reshape %86 {allow_reorder = false} : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> + %87 = tt.reshape %86 allow_reorder : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir index 5073f997d4..5442998671 100644 --- a/test/TritonGPU/optimize-locality.mlir +++ b/test/TritonGPU/optimize-locality.mlir @@ -4,7 +4,7 @@ // CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0.000000e+00> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.addf // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] @@ -207,7 +207,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0xFF800000> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.maximumf // CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] @@ -314,7 +314,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[CST:.*]] = arith.constant dense<0x7F800000> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.minimumf // CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] @@ -421,7 +421,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.mulf // CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] @@ -579,14 +579,14 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-DAG: #[[$BLOCK1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> // CHECK-DAG: #[[$BLOCK2:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> // CHECK-LABEL: optimize_view_layout -// CHECK: %[[R:.+]] = tt.reshape {{.*}} {allow_reorder = true, efficient_layout} : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> +// CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> // CHECK: "tt.reduce"(%[[C]]) #blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> { - %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> + %0 = tt.reshape %arg0 allow_reorder : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = arith.maximumf %arg1, %arg2 : f32 diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir index 5e2b8a3b54..64f3193653 100644 --- a/test/TritonIntelGPU/combine.mlir +++ b/test/TritonIntelGPU/combine.mlir @@ -2075,7 +2075,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { // CHECK-NOT: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> tt.return %c : tensor<32xf32, #blocked3> } @@ -2094,7 +2094,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: tt.reshape // CHECK: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> tt.return %b : tensor<32xf32, #blocked2> } } @@ -2111,7 +2111,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-NOT: triton_gpu.convert_layout // CHECK: arith.truncf // CHECK: triton_gpu.convert_layout - %a = tt.reshape %arg0 {allow_reorder = true, efficient_layout} : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> + %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> %b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> tt.return %c : tensor<32xf16, #blocked2> From 9e60885179af5b8e07c069891a7d5d34cf5de65c Mon Sep 17 00:00:00 2001 From: victor-eds Date: Wed, 16 Oct 2024 11:21:38 +0100 Subject: [PATCH 2/2] Revert semantic changes to tests and add new ones --- test/Triton/invalid.mlir | 2 +- test/Triton/ops.mlir | 21 +++++++++++++++++++++ test/TritonGPU/loop-pipeline.mlir | 2 +- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index c35395ae20..a3826dded0 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) { tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) { // expected-error @+1 {{number of src and dst elements of reshape must be the same}} - %a = tt.reshape %arg0 allow_reorder : tensor<32x128xf16> -> tensor<64x32xf16> + %a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16> tt.return } diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index 94e08dde53..6079eac538 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -225,11 +225,32 @@ tt.func @inline_asm_scalar(%0: i32) { // CHECK-LABEL: reshape tt.func @reshape(%0: tensor<512xi32>) { + // CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32> + tt.return +} + +// CHECK-LABEL: reshape_allow_reorder +tt.func @reshape_allow_reorder(%0: tensor<512xi32>) { // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32> %1 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32> tt.return } +// CHECK-LABEL: reshape_efficient_layout +tt.func @reshape_efficient_layout(%0: tensor<512xi32>) { + // CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + tt.return +} + +// CHECK-LABEL: reshape_allow_reorder_efficient_layout +tt.func @reshape_allow_reorder_efficient_layout(%0: tensor<512xi32>) { + // CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + tt.return +} + // CHECK-LABEL: histogram tt.func @histogram(%0: tensor<512xi32>) { // CHECK: tt.histogram %{{.+}} : tensor<512xi32> -> tensor<16xi32> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index fdb1764174..97d85fcf1c 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1341,7 +1341,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> - %87 = tt.reshape %86 allow_reorder : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> + %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma>