
Commit bbba43a

Merge commit 'd2078949e8d34b3ed9d486fc8ddd7d30329b8f6c'

2 parents: a626ab8 + d207894

File tree

6 files changed: +154 −274 lines


.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -31,6 +31,9 @@ python/triton/language/extra
 # Proton
 python/triton/profiler
 
+# Instrumentation
+python/triton/instrumentation
+
 # Python caches
 __pycache__/
 *.py[cod]
```

lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp

Lines changed: 15 additions & 2 deletions

```diff
@@ -45,6 +45,15 @@ class TF32x3 : public OpRewritePattern<DotOp> {
                             ArrayRef<Value>{value})
           .getResult()[0];
     };
+    auto zeroLike = [&](Value c) -> Value {
+      return rewriter.create<SplatOp>(
+          dotOp->getLoc(), c.getType(),
+          rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
+                                             rewriter.getF32FloatAttr(0)));
+    };
+    auto add = [&](Value a, Value b) -> Value {
+      return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
+    };
     auto sub = [&](Value a, Value b) -> Value {
       return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
     };
@@ -60,11 +69,15 @@ class TF32x3 : public OpRewritePattern<DotOp> {
     auto bBig = f32ToTF32(dotOp.getB());
     auto bSmall = sub(dotOp.getB(), bBig);
 
-    auto dot1 = dot(aSmall, bBig, dotOp.getC());
+    auto zero = zeroLike(dotOp.getC());
+
+    auto dot1 = dot(aSmall, bBig, zero);
     auto dot2 = dot(aBig, bSmall, dot1);
     auto dot3 = dot(aBig, bBig, dot2);
 
-    rewriter.replaceOp(dotOp, dot3);
+    auto sum = add(dot3, dotOp.getC());
+
+    rewriter.replaceOp(dotOp, sum);
     return success();
   }
 };
```
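The change above stops feeding the accumulator C into the first TF32 dot: the three partial products are now accumulated from a zero splat, and C is added once at the end, presumably so the small correction terms are summed at their own scale before meeting a potentially large C. A minimal scalar sketch of the 3xTF32 decomposition this pattern emulates, assuming TF32 rounding can be modeled by truncating an f32 mantissa to 10 bits (toTF32 is a hypothetical stand-in for the pass's f32ToTF32 helper, not Triton code):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical model of TF32 rounding: an f32 has a 23-bit mantissa, TF32
// keeps 10 bits, so zero the low 13 bits (truncation instead of
// round-to-nearest, which is close enough to illustrate the big/small split).
static float toTF32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= ~std::uint32_t{0x1FFF};
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543212f, c = 1000.0f;
  float aBig = toTF32(a), aSmall = a - aBig; // a == aBig + aSmall exactly
  float bBig = toTF32(b), bSmall = b - bBig;
  // The commit's scheme: three partial products accumulated from zero
  // (dropping the tiny aSmall*bSmall term), then one final add of c.
  float dot1 = aSmall * bBig + 0.0f;
  float dot2 = aBig * bSmall + dot1;
  float dot3 = aBig * bBig + dot2;
  float result = dot3 + c;
  std::printf("3xTF32: %.9g  fp32 reference: %.9g\n", result, a * b + c);
  return 0;
}
```

Running this with a large c makes the difference visible: the old ordering (dot1 = aSmall * bBig + c) exposes each small partial product to rounding against the large accumulator, while the new ordering does that only once, in the final dot3 + c.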

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 91 additions & 19 deletions

```diff
@@ -1,28 +1,100 @@
 // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s
 
-// Check that we order load, local_alloc, local_store (optional) and local_load one after another. This is useful
-// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
+// Check that we place local_alloc, local_store (optional) and local_load right after the definition of their operands
+// in cases where local_alloc is in the loop but its operand is not.
+// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
 // throughout the computation.
-// CHECK-LABEL: order_load_alloc_local_load
-// CHECK: %[[LOAD:.+]] = tt.load
-// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]]
-// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
-#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
-module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
-    %9 = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
-    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared>
-    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-    %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
-    %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
-    tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
+
+// CHECK-LABEL: hoist_q_out_of_the_loop
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
+// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
+// CHECK: scf.for
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
+    %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
+    %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
+    %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
+    %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
+      %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
+      %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
+      %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      %107 = arith.addi %arg26, %c128_i64 : i64
+      scf.yield %107 : i64
+    } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
     tt.return
   }
 }
+
+
+// -----
+// Check that the reordering described in hoist_q_out_of_the_loop is not done in the case where both
+// local_alloc and its src tensor's defining op are in the loop.
+// CHECK-LABEL: no_hoist_q_type_reordering
+// CHECK: scf.for
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: arith.constant
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
+    %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
+    %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
+    %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
+      %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
+      %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
+      %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
+      %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
+      %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      %107 = arith.addi %arg26, %c128_i64 : i64
+      scf.yield %107 : i64
+    } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
+    tt.return
+  }
+}
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+
 // CHECK-LABEL: order_load_alloc_local_load_local_store
 // CHECK: %[[LOAD:.+]] = tt.load
 // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc
```
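Together, the two new tests pin down when this pass may move a local_alloc: it is placed right after its source tensor's defining op only when that definition lives outside the enclosing loop. A minimal sketch of that criterion against the MLIR API (LocalAllocOp and getSrc are real Triton GPU dialect names; the helper function and its exact placement logic are illustrative assumptions, not the pass's actual code):

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Illustrative only: move a local_alloc next to the op defining its source
// tensor, but only when that definition sits outside the enclosing scf.for,
// matching hoist_q_out_of_the_loop vs. no_hoist_q_type_reordering above.
static void hoistLocalAllocNextToDef(triton::gpu::LocalAllocOp alloc) {
  Value src = alloc.getSrc();
  if (!src)
    return; // alloc without a source tensor: nothing to anchor to
  Operation *def = src.getDefiningOp();
  if (!def)
    return; // source is a block argument
  auto loop = alloc->getParentOfType<scf::ForOp>();
  if (!loop || loop->isAncestor(def))
    return; // both ops inside the loop (or no loop): leave the IR unchanged
  alloc->moveAfter(def); // alloc now trails e.g. the truncf, outside the loop
}
```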
