Commit cbd5d48

[AMD] Drop deprecated pattern in OptimizeDotOperands pass (#8716)
1 parent c3c65b9 commit cbd5d48

2 files changed: +1, -306 lines

test/TritonGPU/amd/amd-optimize-dot-operands.mlir

Lines changed: 1 addition & 126 deletions
@@ -1,129 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritonamdgpu-optimize-dot-operands="arch-generation-name=gfx950" | FileCheck %s --check-prefixes CHECK,GFX950
-// RUN: triton-opt %s -split-input-file -tritonamdgpu-optimize-dot-operands="arch-generation-name=gfx942" | FileCheck %s
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
-#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}>
-#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16, 32], isTransposed = true}>
-// CHECK{LITERAL}: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
-// CHECK{LITERAL}: #smem = #ttg.shared_memory
-// CHECK-LABEL: test_local_load_transposed
-// CHECK: %[[LOAD:.+]] = tt.load {{.*}} : tensor<64x16x!tt.ptr<f16>, #blocked>
-// CHECK: %[[ALLOC:.+]] = ttg.local_alloc %[[LOAD]] : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>
-// CHECK: %[[LOCAL_LOAD_TRANS:.+]] = ttg.local_load %[[ALLOC]] : !ttg.memdesc<64x16xf16, #shared, #smem> -> tensor<64x16xf16, #linear>
-// CHECK: %[[LOCAL_LOAD_DIRECT:.+]] = ttg.local_load %[[ALLOC]] : !ttg.memdesc<64x16xf16, #shared, #smem> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-// CHECK: tt.dot {{.+}}, %[[LOCAL_LOAD_DIRECT]], {{.+}}: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x16xf32, #mma>
-// CHECK: %[[TRANS:.+]] = tt.trans %[[LOCAL_LOAD_TRANS]] {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
-// CHECK: tt.dot {{.+}}, %[[TRANS]], {{.+}} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>> -> tensor<128x64xf32, #mma1>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @test_local_load_transposed(
-    %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
-    %out0 : tensor<128x16x!tt.ptr<f32>, #blocked>,
-    %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
-  ) {
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
-    %cst_1 = arith.constant dense<0.693147182> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>>
-    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
-    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
-
-    %0 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
-    %1 = ttg.convert_layout %0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
-    %2 = ttg.convert_layout %0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
-    %3 = tt.dot %cst_1, %2, %cst_0 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>> -> tensor<128x16xf32, #mma1>
-    %4 = tt.trans %1 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-    %5 = tt.dot %cst_2, %4, %cst_3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
-
-    %6 = ttg.convert_layout %3 : tensor<128x16xf32, #mma1> -> tensor<128x16xf32, #blocked>
-    %7 = ttg.convert_layout %5 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
-    tt.store %out0, %6 : tensor<128x16x!tt.ptr<f32>, #blocked>
-    tt.store %out1, %7 : tensor<128x64x!tt.ptr<f32>, #blocked>
-    tt.return
-  }
-}
-// -----
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
-#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}>
-#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16, 32], isTransposed = true}>
-// CHECK-NOT: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
-// CHECK-NOT: #smem = #ttg.shared_memory
-// CHECK-LABEL: test_not_local_load_transposed_kWidth_mismatch
-// CHECK: tt.load {{.*}} : tensor<64x16x!tt.ptr<f16>, #blocked>
-// CHECK-NOT: ttg.local_alloc
-// CHECK-NOT: ttg.local_load
-// CHECK-NOT: ttg.local_load
-// CHECK: tt.dot {{.+}}: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x16xf32, #mma>
-// CHECK: tt.trans {{.+}} {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
-// CHECK: tt.dot {{.+}} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>> -> tensor<128x64xf32, #mma1>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @test_not_local_load_transposed_kWidth_mismatch(
-    %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
-    %out0 : tensor<128x16x!tt.ptr<f32>, #blocked>,
-    %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
-  ) {
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
-    %cst_1 = arith.constant dense<0.693147182> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
-    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
-    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
-
-    %0 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
-    %1 = ttg.convert_layout %0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
-    %2 = ttg.convert_layout %0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
-    %3 = tt.dot %cst_1, %2, %cst_0 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
-    %4 = tt.trans %1 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-    %5 = tt.dot %cst_2, %4, %cst_3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
-
-    %6 = ttg.convert_layout %3 : tensor<128x16xf32, #mma1> -> tensor<128x16xf32, #blocked>
-    %7 = ttg.convert_layout %5 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
-    tt.store %out0, %6 : tensor<128x16x!tt.ptr<f32>, #blocked>
-    tt.store %out1, %7 : tensor<128x64x!tt.ptr<f32>, #blocked>
-    tt.return
-  }
-}
-// -----
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
-#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 0], [0, 0]], block = []}>
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}>
-#mma1 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [16, 16, 32], isTransposed = true}>
-// CHECK-NOT: #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [0, 1]}>
-// CHECK-NOT: #smem = #ttg.shared_memory
-// CHECK-LABEL: test_not_local_load_transposed_opIdx_mismatch
-// CHECK: tt.load {{.*}} : tensor<64x16x!tt.ptr<f16>, #blocked>
-// CHECK-NOT: ttg.local_alloc
-// CHECK-NOT: ttg.local_load
-// CHECK-NOT: ttg.local_load
-// CHECK: tt.dot {{.+}}: tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x64xf32, #mma>
-// CHECK: tt.trans {{.+}} {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
-// CHECK: tt.dot {{.+}} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>> -> tensor<128x64xf32, #mma1>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @test_not_local_load_transposed_opIdx_mismatch(
-    %arg0: tensor<64x16x!tt.ptr<f16>, #blocked>,
-    %out0 : tensor<64x64x!tt.ptr<f32>, #blocked>,
-    %out1 : tensor<128x64x!tt.ptr<f32>, #blocked>
-  ) {
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1>
-    %cst_1 = arith.constant dense<0.693147182> : tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>>
-    %cst_2 = arith.constant dense<0.581374812> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
-    %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
-
-    %0 = tt.load %arg0 : tensor<64x16x!tt.ptr<f16>, #blocked>
-    %1 = ttg.convert_layout %0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #linear>
-    %2 = ttg.convert_layout %0 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>>
-    %3 = tt.dot %2, %cst_1, %cst_0 : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 8}>> -> tensor<64x64xf32, #mma1>
-    %4 = tt.trans %1 {order = array<i32: 1, 0>} : tensor<64x16xf16, #linear> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-    %5 = tt.dot %cst_2, %4, %cst_3 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x64xf32, #mma>
-
-    %6 = ttg.convert_layout %3 : tensor<64x64xf32, #mma1> -> tensor<64x64xf32, #blocked>
-    %7 = ttg.convert_layout %5 : tensor<128x64xf32, #mma> -> tensor<128x64xf32, #blocked>
-    tt.store %out0, %6 : tensor<64x64x!tt.ptr<f32>, #blocked>
-    tt.store %out1, %7 : tensor<128x64x!tt.ptr<f32>, #blocked>
-    tt.return
-  }
-}
-
-// -----
+// RUN: triton-opt %s -split-input-file -tritonamdgpu-optimize-dot-operands="arch-generation-name=gfx950" | FileCheck %s --check-prefixes GFX950

 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
 #linear = #ttg.linear<{register = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [1, 0, 0], [2, 0, 0], [0, 32, 0], [0, 64, 0]], lane = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 0, 8], [0, 0, 16]], warp = [[0, 16, 0]], block = []}>

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeDotOperands.cpp

Lines changed: 0 additions & 180 deletions
@@ -26,185 +26,6 @@ namespace mlir::triton::amdgpu {

 namespace {

-// Detect a pair of tt.dot ops that both use the same tt.load result, one
-// directly and one via tt.trans and use the same shared memory buffer in this
-// case. Given:
-//   load -> cvt -> .. -> dot1
-//        -> cvt -> .. -> trans -> cvt -> .. -> dot2
-// Rewrite to:
-//   load -> local_alloc -> local_load -> dot1
-//        -> local_load_transposed -> dot2
-class ReuseShmemForDirectAndTransposedUse : public OpRewritePattern<LoadOp> {
-public:
-  ReuseShmemForDirectAndTransposedUse(MLIRContext *context,
-                                      triton::AMD::ISAFamily isaFamily)
-      : OpRewritePattern(context), isaFamily(isaFamily) {}
-
-  LogicalResult matchAndRewrite(tt::LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto numUsers = llvm::range_size(loadOp->getUsers());
-    if (numUsers < 2) {
-      return rewriter.notifyMatchFailure(loadOp,
-                                         "load op must have at least 2 users");
-    }
-
-    auto srcTy = dyn_cast<RankedTensorType>(loadOp.getType());
-    if (!srcTy) {
-      return rewriter.notifyMatchFailure(loadOp, "src type must be a tensor");
-    }
-
-    LDBG("ReuseShmemForDirectAndTransposedUse for load Op: " << *loadOp);
-
-    tt::DotOpInterface directDot = nullptr;
-    tt::DotOpInterface transDot = nullptr;
-    ttg::ConvertLayoutOp cvtOp = nullptr;
-    unsigned directOpIdx = 0;
-    unsigned transOpIdx = 0;
-
-    auto followConvertLayoutChain =
-        [](mlir::Value &usedValue, mlir::Operation *op) -> mlir::Operation * {
-      while (isa<ttg::ConvertLayoutOp>(op)) {
-        // Ensure we have exactly one user
-        if (!(op->hasOneUse()))
-          return nullptr;
-        usedValue = op->getResult(0);
-        op = *(op->getUsers().begin());
-      }
-
-      return op;
-    };
-
-    mlir::Value usedValue;
-    for (mlir::Operation *user : loadOp->getUsers()) {
-      auto op = user;
-
-      op = followConvertLayoutChain(usedValue, op);
-
-      if (auto transOp = dyn_cast_or_null<tt::TransOp>(op)) {
-        LDBG("Found tranpose op: " << *transOp);
-        cvtOp = transOp.getSrc().getDefiningOp<ttg::ConvertLayoutOp>();
-        LDBG("Found parent cvt op of transpose: " << *cvtOp);
-        usedValue = transOp->getResult(0);
-        op =
-            followConvertLayoutChain(usedValue, *(transOp->getUsers().begin()));
-        if (auto dotOp = dyn_cast<tt::DotOpInterface>(op)) {
-          transDot = dotOp;
-          transOpIdx = (usedValue == dotOp.getA()) ? 0 : 1;
-        }
-      } else if (auto dotOp = dyn_cast_or_null<tt::DotOpInterface>(op)) {
-        directDot = dotOp;
-        directOpIdx = (usedValue == dotOp.getA()) ? 0 : 1;
-      }
-
-      if (directDot && transDot)
-        break;
-    }
-
-    if (!directDot)
-      return rewriter.notifyMatchFailure(loadOp,
-                                         "expected a direct tt.dot user");
-    if (!transDot)
-      return rewriter.notifyMatchFailure(
-          loadOp, "expected a tt.trans feeding a tt.dot user");
-    if (directOpIdx != transOpIdx) {
-      return rewriter.notifyMatchFailure(loadOp, [&](mlir::Diagnostic &d) {
-        d << "operand indices of direct and transposed tt.dot users must be "
-             "the same. Got indices: direct: "
-          << directOpIdx << " and transposed: " << transOpIdx;
-      });
-    }
-
-    LDBG("load is shared between transposed and non-transposed users");
-    LDBG("Non-transposed access tt.dot: " << *directDot);
-    LDBG("Transposed access tt.dot: " << *transDot);
-
-    unsigned opIdx = directOpIdx;
-
-    auto directOperandType =
-        cast<RankedTensorType>(directDot->getOperand(opIdx).getType());
-    auto transOperandType =
-        cast<RankedTensorType>(transDot->getOperand(opIdx).getType());
-    auto directDotEnc =
-        dyn_cast<ttg::DotOperandEncodingAttr>(directOperandType.getEncoding());
-    auto transDotEnc =
-        dyn_cast<ttg::DotOperandEncodingAttr>(transOperandType.getEncoding());
-
-    if (!directDotEnc || !transDotEnc) {
-      return rewriter.notifyMatchFailure(loadOp,
-                                         "wrong encodings for tt.dot users");
-    }
-
-    if (directDotEnc.getKWidth() != transDotEnc.getKWidth()) {
-      return rewriter.notifyMatchFailure(loadOp, [&](mlir::Diagnostic &d) {
-        d << "kWidths are mismatching. direct: " << directDotEnc.getKWidth()
-          << " and transposed: " << transDotEnc.getKWidth();
-      });
-    }
-
-    // We need to ensure that the parents of direct and transposed dot encodings
-    // are matching in order to get the same shared memory encoding. Note that
-    // they can have different instrShape(s) (mfma instructions) but still map
-    // to the same shared memory encoding.
-    auto directCTALayout = ttg::getCTALayout(directDotEnc);
-    auto transCTALayout = ttg::getCTALayout(transDotEnc);
-
-    if (directCTALayout != transCTALayout) {
-      return rewriter.notifyMatchFailure(
-          loadOp,
-          "CTA layouts of direct and transposed tt.dot users are mismatching");
-    }
-
-    auto ctx = getContext();
-    auto sharedOrder = ttg::getOrderForMemory(srcTy);
-    auto sharedEnc = ttg::SwizzledSharedEncodingAttr::get(
-        ctx, directDotEnc, directOperandType.getShape(), sharedOrder,
-        directCTALayout, directOperandType.getElementType(),
-        /*needTrans=*/false);
-
-    LDBG("Created shared encoding: " << sharedEnc);
-    rewriter.setInsertionPointAfter(loadOp);
-    auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(ctx);
-    Location loc = loadOp.getLoc();
-    auto alloc = ttg::LocalAllocOp::create(
-        rewriter, loc,
-        ttg::MemDescType::get(srcTy.getShape(), srcTy.getElementType(),
-                              sharedEnc, sharedMemorySpace),
-        loadOp.getResult());
-    LDBG("Created local alloc op: " << *alloc);
-    auto localLoad =
-        ttg::LocalLoadOp::create(rewriter, loc, directOperandType, alloc);
-    LDBG("Created local load op:" << *localLoad);
-    rewriter.modifyOpInPlace(
-        directDot, [&]() { directDot->setOperand(opIdx, localLoad); });
-    LDBG("Updated Direct dot: " << *directDot);
-    if (!canUseLocalLoadTransposed(opIdx, sharedOrder)) {
-      rewriter.modifyOpInPlace(cvtOp, [&]() {
-        cvtOp.getSrcMutable().assign(localLoad.getResult());
-      });
-      LDBG("Updated cvt op: " << *cvtOp);
-    } else {
-      return rewriter.notifyMatchFailure(loadOp, "currently not supported");
-    }
-
-    LDBG("Updated Trans dot: " << *transDot);
-
-    return success();
-  }
-
-private:
-  bool canUseLocalLoadTransposed(unsigned opIdx,
-                                 ArrayRef<unsigned> sharedOrder) const {
-    // TODO(PMylon): Comment out for now, until lowering from
-    // local_load_transposed to ds_read_tr is supported.
-    // unsigned kDimIdx = (opIdx == 0) ? 1 : 0;
-    // bool isCDNA4 = (isaFamily == triton::AMD::ISAFamily::CDNA4);
-    // bool isKContig = (sharedOrder[0] == kDimIdx);
-    return false;
-  }
-
-  triton::AMD::ISAFamily isaFamily;
-};
-
 // This pattern creates LocalAllocOp and LocalLoadOp with unswizzled shared
 // layout for the scale operand used in ScaledUpcastFp4Op/ScaledUpcastFp8Op.
 // StreamPipeliner will respect the layout created here and pipeline ops
@@ -304,7 +125,6 @@ class TritonAMDGPUOptimizeDotOperands

     mlir::RewritePatternSet patterns(context);
     auto isaFamily = triton::AMD::deduceISAFamily(archGenerationName);
-    patterns.add<ReuseShmemForDirectAndTransposedUse>(context, isaFamily);
     patterns
         .add<AllocSharedMemForUpcastedScales<tt::amdgpu::ScaledUpcastFp8Op>,
              AllocSharedMemForUpcastedScales<tt::amdgpu::ScaledUpcastFp4Op>>(
