Skip to content

Commit 6e390f3

Browse files
authored
[BACKEND] Restrict pattern re-ordering of alloc and reshape (#8266)
This avoids an extra round trip through registers when the user needs an MMA layout.
1 parent d5d1c6d commit 6e390f3

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,12 @@ class ReshapeMemDesc : public OpRewritePattern<LocalAllocOp> {
165165
getContext(), allocOp.getLoc(), allocType, srcShape, innerTy)))
166166
return failure();
167167

168+
// For now don't apply the transformation if the new encoding is not an
169+
// MMAv3/v5 encoding as it may not be compatible with the user.
170+
// The heuristic can be refined once we have more flexible mma ops.
171+
if (!isa<NVMMASharedEncodingAttr>(innerTy.getEncoding()))
172+
return failure();
173+
168174
auto newAlloc = rewriter.create<LocalAllocOp>(allocOp.getLoc(), innerTy,
169175
reshapeOp.getSrc());
170176
rewriter.replaceOpWithNewOp<MemDescReshapeOp>(allocOp, allocOp.getType(),

test/TritonGPU/dot-operands.mlir

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -276,3 +276,20 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num
276276
tt.return %a: !ttg.memdesc<128x64xf16, #shared, #smem>
277277
}
278278
}
279+
280+
// -----
281+
282+
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 32], warpsPerCTA = [1, 2, 2, 1], order = [3, 2, 1, 0]}>
283+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
284+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}>
285+
#smem = #ttg.shared_memory
286+
module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
287+
// CHECK-LABEL: @reshape_memdesc_negative
288+
tt.func @reshape_memdesc_negative(%arg: tensor<1x32x2x64xf32, #blocked>) -> !ttg.memdesc<64x64xf32, #shared, #smem> {
289+
%r = tt.reshape %arg : tensor<1x32x2x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
290+
// CHECK-NOT: ttg.memdesc_reshape
291+
%a = ttg.local_alloc %r : (tensor<64x64xf32, #blocked1>) -> !ttg.memdesc<64x64xf32, #shared, #smem>
292+
// CHECK: tt.return
293+
tt.return %a: !ttg.memdesc<64x64xf32, #shared, #smem>
294+
}
295+
}

0 commit comments

Comments (0)