
Commit 1064b59

[BACKEND] Propagate mma layout to following elementwise operations. (#3973)
For a matmul followed by arithmetic operations, such as `acc += tl.dot(a, b)`, the mma layout of the `dot` result currently isn't propagated into the subsequent `add`. As a result, when the dot is inside a loop, there are repeated layout conversions from mma to blocked on every iteration. I'm fixing this by allowing the mma layout to be propagated so that it can be reused.
1 parent ed39cb0 commit 1064b59
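
To illustrate the pattern this change optimizes, below is a minimal Triton kernel sketch of the `acc += tl.dot(a, b)` idiom described above. The kernel name, signature, and the simplified row-major addressing are hypothetical (not taken from this PR); they only show where the repeated mma-to-blocked conversion used to occur.

```python
import triton
import triton.language as tl


@triton.jit
def matmul_add_kernel(a_ptr, b_ptr, c_ptr, K,  # hypothetical kernel, illustrative only
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # The accumulator starts as a blocked-layout tensor; each tl.dot result
    # carries an mma layout after lowering to TritonGPU.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    for k in range(0, K, BLOCK_K):
        # Simplified row-major addressing without masks; a real kernel would
        # pass strides and handle boundary conditions.
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(b_ptr + (k + offs_k)[:, None] * BLOCK_N + offs_n[None, :])
        # Before this change, the += below forced a convert_layout from mma back
        # to blocked on every iteration; with mma propagation the accumulator can
        # stay in the mma layout across the loop.
        acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)
```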

3 files changed (+45, -107 lines)


lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 3 additions & 92 deletions
@@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() {
   op->erase();
 }
 
-// Look ahead to at the transitive uses and see if there is a convert to mma
-// operations.
-bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
-  SmallVector<Value> queue = {op->getResult(0)};
-  SetVector<Operation *> forwardSlice;
-  llvm::SmallDenseSet<Value> seen;
-  while (!queue.empty()) {
-    Value currentValue = queue.back();
-    queue.pop_back();
-    getForwardSlice(currentValue, &forwardSlice);
-    for (Operation *op : forwardSlice) {
-      // HACK: Stop propagation if the ReduceOp is using mma layout but is
-      // producing tensor smaller than the layout we would like to propagate.
-      // This is to avoid stepping into the known bug.
-      if (isa<mlir::triton::ReduceOp>(op)) {
-        auto tensorType =
-            dyn_cast<RankedTensorType>(op->getOperand(0).getType());
-        if (tensorType &&
-            isa<NvidiaMmaEncodingAttr>(tensorType.getEncoding())) {
-          auto mmaInstrShape =
-              cast<NvidiaMmaEncodingAttr>(encoding).getInstrShape();
-          if (tensorType.getShape()[tensorType.getRank() - 2] <
-                  mmaInstrShape[0] ||
-              tensorType.getShape()[tensorType.getRank() - 1] <
-                  mmaInstrShape[1]) {
-            return false;
-          }
-        }
-      }
-
-      if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
-        Attribute dstEncoding = convertOp.getType().getEncoding();
-        if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding))
-          return (mmaLayout.getVersionMajor() > 1) ? true
-                                                   : mmaLayout == encoding;
-        if (isa<triton::gpu::AMDMfmaEncodingAttr,
-                triton::gpu::AMDWmmaEncodingAttr>(dstEncoding))
-          return true;
-        if (isa<triton::gpu::DotOperandEncodingAttr>(dstEncoding)) {
-          if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
-            return mmaLayout.getVersionMajor() > 1;
-          } else {
-            assert((mlir::isa<triton::gpu::AMDMfmaEncodingAttr,
-                              triton::gpu::AMDWmmaEncodingAttr>(encoding)));
-            return true;
-          }
-        }
-      }
-      bool isMMAV3 =
-          isa<NvidiaMmaEncodingAttr>(encoding) &&
-          cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
-      if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
-        return true;
-      auto yield = dyn_cast<scf::YieldOp>(op);
-      if (!yield)
-        continue;
-      if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
-        for (OpOperand &operand : yield->getOpOperands()) {
-          Operation *def = operand.get().getDefiningOp();
-          if (def &&
-              (forwardSlice.count(def) || operand.get() == currentValue) &&
-              (seen.insert(operand.get()).second == true))
-            queue.push_back(ifOp.getResult(operand.getOperandNumber()));
-        }
-      }
-      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
-      if (!forOp)
-        continue;
-      for (OpOperand &operand : yield->getOpOperands()) {
-        Operation *def = operand.get().getDefiningOp();
-        if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
-            (seen.insert(operand.get()).second == true))
-          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
-      }
-    }
-  }
-  return false;
-}
-
 // Return true if the op is an op with a layout we don't want to change. We will
 // propagate the layout starting from anchor ops.
 bool isLayoutAnchor(Operation *op) {
@@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) {
 }
 
 void LayoutPropagation::initAnchorLayout() {
-  auto maybeAddAnchor = [&](Value v) {
+  auto addAnchor = [&](Value v) {
     if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
-      // Workaround, don't popagate MMA layout unless there is a convert
-      // back to mma further down to avoid generating reduction with MMA
-      // layout that may have lower performance.
-      // This can be improved with more aggressive backward propagation.
-      if (isa<MmaEncodingTrait>(tensorType.getEncoding()) &&
-          v.getDefiningOp() &&
-          !hasConvertToMMATransisitiveUse(v.getDefiningOp(),
-                                          tensorType.getEncoding())) {
-        return;
-      }
       layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
     }
   };
@@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() {
   // you can pass a tensor with an encoding as an arg, instead of explicitly
   // calling tt.load.
   for (auto arg : funcOp.getArguments()) {
-    maybeAddAnchor(arg);
+    addAnchor(arg);
   }
 
   funcOp.walk([&](Operation *op) {
     if (isLayoutAnchor(op)) {
       for (auto result : op->getResults()) {
-        maybeAddAnchor(result);
+        addAnchor(result);
       }
     }
   });

python/test/unit/language/test_core.py

Lines changed: 0 additions & 15 deletions
@@ -3222,21 +3222,6 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri,
                          w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs)
 
-    if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"):
-        if not is_cuda():
-            pass
-        else:
-            ptx = pgm.asm["ptx"]
-            start = ptx.find("shfl.sync.bfly")
-            end = ptx.find("cvt.rn.f16.f32")
-            red_code = ptx[start:end]
-            assert len(red_code) > 0
-
-            # skip this check on hopper because there are some functions whose name contain "shared" in ptx.
-            # TODO: we should eliminate these unused functions in ptx code.
-            if not (capability[0] >= 9):
-                assert "shared" not in red_code
-            assert "bar.sync" not in red_code
     # torch result
     if in_dtype == 'int8':
         z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32)

test/TritonGPU/combine.mlir

Lines changed: 42 additions & 0 deletions
@@ -2607,3 +2607,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return %outLHS : tensor<128x64xf32, #blocked1>
   }
 }
+
+// -----
+
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
+#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>
+
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} {
+  // CHECK-LABEL: matmul_add
+  tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>, %C : !tt.ptr<f32>) {
+    %a_ptr_init = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #AL>
+    %b_ptr_init = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #BL>
+    %c_ptr_init = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #CL>
+    %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL>
+    %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
+    %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
+    %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
+
+    %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>) {
+      %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
+      %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
+      %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
+      %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
+      %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
+      %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL>
+      // CHECK: %[[T0:.*]] = tt.dot
+      // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma>
+      %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL>
+      %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+      %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
+      // CHECK: scf.yield
+      scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #CL>
+    }
+
+    // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked
+    tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr<f32>, #CL>
+    tt.return
+  }
+}

0 commit comments
