Commit e558838

[BACKEND] Use LL to simplify redundant elements check and fix related issues (#5225)
1 parent 4330372 commit e558838

File tree

2 files changed: +70 -104 lines changed

2 files changed

+70
-104
lines changed

python/test/unit/language/test_core.py

Lines changed: 5 additions & 16 deletions
@@ -5436,21 +5436,11 @@ def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
         pytest.skip("Skip testing MMAv3 on devices with CC < 9")

     num_warps = np.cumprod(src_layout.warps_per_cta)[-1]
-    # TODO(Keren): Remove the intermediate layout once we have resolved the redundantDataMask issue for WGMMA
-    warps_per_cta = src_layout.warps_per_cta
-    interm = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [warps_per_cta[0], warps_per_cta[1]], [0, 1], [1, 1],
-                           [1, 1], [0, 1])

     def do_test(src_layout, dst_layout):
         layouts = f"""
         #src = {src_layout}
         #dst = {dst_layout}
-        #interm = {interm}
-        """
-
-        conversion = f"""
-        %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
-        %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
         """

         ir = layouts + f"""
@@ -5460,6 +5450,7 @@ def do_test(src_layout, dst_layout):
     %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
     %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
     %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
+    %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
     %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src>
     %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
     %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src>
@@ -5468,12 +5459,10 @@ def do_test(src_layout, dst_layout):
     %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
     %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
     %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<f16>, #src>
-    %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #interm>
-    """ + conversion + f"""
-    %15 = triton_gpu.convert_layout %12 : tensor<{M}x{N}xi32, #dst> -> tensor<{M}x{N}xi32, #interm>
-    %16 = triton_gpu.convert_layout %13 : tensor<{M}x{N}xf16, #dst> -> tensor<{M}x{N}xf16, #interm>
-    %17 = tt.addptr %3, %15 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>, tensor<{M}x{N}xi32, #interm>
-    tt.store %17, %16 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>
+    %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
+    %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
+    %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
+    tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>
     tt.return
   }}
 }}
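Note on the test change above: with the redundancy mask now derived from the linear layout (see LoadStoreOpToLLVM.cpp below), tt.store can be emitted directly in the #dst MMA/WGMMA layout, so the blocked #interm layout and the extra convert_layout round trip are no longer needed. The standalone C++ sketch below, with made-up numbers and no Triton APIs, illustrates what a "redundant element" is: when a layout replicates an element across several lanes, only the lane whose free bits are all zero may store it.

    // Standalone illustration (no Triton APIs): a toy layout that maps 32 lanes
    // onto 8 tensor elements.  Lane bits 3 and 4 do not change which element a
    // lane holds, so they are "free variables": four lanes carry a copy of every
    // element, and a store must be predicated so only one copy is written.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int numLanes = 32;
      const uint32_t laneFreeMask = 0b11000; // free lane bits in this toy layout

      for (int lane = 0; lane < numLanes; ++lane) {
        int element = lane & 0b00111;               // bits that pick the element
        bool mayWrite = (lane & laneFreeMask) == 0; // the redundancy mask
        std::printf("lane %2d -> element %d: %s\n", lane, element,
                    mayWrite ? "writes" : "masked (redundant copy)");
      }
      return 0;
    }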

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 65 additions & 88 deletions
@@ -8,6 +8,7 @@
 #include "Utility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"

 using namespace mlir;
@@ -24,87 +25,57 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 namespace {

 // Return the mask for the unique data accessed by given tensor type.
-// Used to mask out the redundant data accessed by threads.
-Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
-                        Location loc, const NVIDIA::TargetInfo &targetInfo) {
+// NOTE: Redundant memory load is allowed in triton, but redundant memory store
+// is not allowed.
+// mask = true: thread can write
+// mask = false: thread should not write
+Value getRedundantDataMask(ModuleOp moduleOp, Type valueTy,
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           int regIdx, const NVIDIA::TargetInfo &targetInfo) {
+  auto ctx = moduleOp.getContext();
   auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
-  Value mask = int_val(1, 1);
+  auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
   auto tid = tid_val();
-  auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
+  auto mask = true_val();
+  auto kReg = str_attr("register");
+  auto kLane = str_attr("lane");
+  auto kWarp = str_attr("warp");
+  auto kBlock = str_attr("block");
   if (tensorTy) {
-    auto layout = tensorTy.getEncoding();
     auto shape = tensorTy.getShape();
-    unsigned rank = shape.size();
-    auto sizePerThread = triton::gpu::getSizePerThread(layout);
-    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
-    auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
-    auto threadOrder = triton::gpu::getThreadOrder(layout);
-    SmallVector<unsigned> warpOrder(rank);
-    if (auto enc = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      warpOrder =
-          triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1);
+    auto layout = tensorTy.getEncoding();
+    auto ll = triton::gpu::toLinearLayout(shape, layout);
+    assert(ll.has_value() && "Failed to convert layout to linear layout");
+    auto freeVariableMasks = ll->getFreeVariableMasks();
+    auto regMasks = freeVariableMasks[kReg];
+    if (regMasks & regIdx) {
+      // Step 1: check register redundancy
+      mask = false_val();
     } else {
-      warpOrder = triton::gpu::getWarpOrder(layout);
-    }
-    auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
-    Value warpSize = i32_val(32);
-    Value laneId = urem(tid, warpSize);
-    Value warpId = udiv(tid, warpSize);
-    // TODO: [DOT LL]
-    // The delinearize function is not entirely correct for certain layouts,
-    // such as wgmma. The correct approach is to convert a legacy layout to its
-    // corresponding linear layout and use the linear layout's
-    // getFreeVariableMasks to identify redundant elements.
-    SmallVector<Value> multiDimWarpId =
-        delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
-    SmallVector<Value> multiDimThreadId =
-        delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
-    for (unsigned dim = 0; dim < rank; ++dim) {
-      // if there is no data replication across threads on this dimension
-      if (shape[dim] >= shapePerCTATile[dim])
-        continue;
-      // Otherwise, we need to mask threads that will replicate data on this
-      // dimension. Calculate the thread index on this dimension for the CTA
-      Value threadDim =
-          add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
-              multiDimThreadId[dim]);
-      mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
-                                 i32_val(shape[dim])));
-    }
-    // Do not write duplicated data when multicast is enabled
-    if (triton::gpu::getNumCTAs(layout) > 1) {
-      auto _0 = i32_val(0);
-      auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
-      auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
-      auto CTAOrder = triton::gpu::getCTAOrder(layout);
-
-      auto multiDimClusterCTAId =
-          delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
-
-      for (unsigned dim = 0; dim < rank; ++dim) {
-        // Skip when multicast is not enabled in this dimension
-        if (CTAsPerCGA[dim] == CTASplitNum[dim])
-          continue;
-        // This wrapping rule must be consistent with emitCTAOffsetForLayout
-        unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
-        Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum));
-        // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
-        // CTA0 and CTA2 holds data of block0,
-        // CTA1 and CTA3 holds data of block1.
-        // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
-        // be masked. We add the following mask:
-        //   multiDimClusterCTAId[dim] / splitNum == 0
-        // Actually in all existing cases of multicast, splitNum is always 1.
-        // The mask is equivalent to:
-        //   multiDimClusterCTAId[dim] == 0
-        mask = and_(mask, icmp_eq(repId, _0));
+      Value warpSize =
+          i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(moduleOp));
+      Value laneId = urem(tid, warpSize);
+      Value warpId = udiv(tid, warpSize);
+      // Step 2: check lane and warp redundancy
+      auto laneMasks = freeVariableMasks[kLane];
+      auto warpMasks = freeVariableMasks[kWarp];
+      mask = and_(mask, icmp_eq(and_(i32_val(laneMasks), laneId), i32_val(0)));
+      mask = and_(mask, icmp_eq(and_(i32_val(warpMasks), warpId), i32_val(0)));
+      if (numCTAs > 1) {
+        // Step 3: check block redundancy
+        auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
+        auto ctaMasks = freeVariableMasks[kBlock];
+        mask = and_(mask, icmp_eq(and_(i32_val(ctaMasks), ctaId), i32_val(0)));
       }
     }
   } else {
-    // If the tensor is not ranked, then it is a scalar and only thread 0 of
-    // CTA0 can write
-    mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0)));
     mask = and_(mask, icmp_eq(tid, i32_val(0)));
+    if (numCTAs > 1) {
+      auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
+      // If the tensor is not ranked, then it is a scalar and only thread 0 of
+      // CTA0 within the cluster can write
+      mask = and_(mask, icmp_eq(ctaId, i32_val(0)));
+    }
   }
   return mask;
 }
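The rewritten helper reads the linear layout's free-variable masks, which report, for each hardware input dimension ("register", "lane", "warp", "block"), the index bits that do not change which tensor element is addressed. Any index with a free bit set holds a duplicate of the element owned by the index with those bits cleared. The standalone sketch below mirrors that predicate with plain integers; the mask values and the mayStore helper are hypothetical, not code from this commit.

    // Standalone sketch of the predicate built by getRedundantDataMask, using
    // plain integers instead of MLIR Values.  The free-variable masks below are
    // hypothetical; in Triton they come from the linear layout.
    #include <cstdint>

    struct FreeVarMasks {
      uint32_t reg;   // register-index bits that do not change the element
      uint32_t lane;  // lane-id bits that do not change the element
      uint32_t warp;  // warp-id bits that do not change the element
      uint32_t block; // cta-id bits that do not change the element
    };

    // True if (regIdx, laneId, warpId, ctaId) owns the unique copy of the
    // element and is therefore allowed to store it.
    bool mayStore(const FreeVarMasks &m, int regIdx, int laneId, int warpId,
                  int ctaId, int numCTAs) {
      if (m.reg & regIdx)
        return false; // redundant register slot, known at compile time
      bool unique = (m.lane & laneId) == 0 && (m.warp & warpId) == 0;
      if (numCTAs > 1)
        unique = unique && (m.block & ctaId) == 0;
      return unique;
    }

    int main() {
      // Hypothetical layout: each element replicated across 2 register slots
      // (reg bit 0), 4 lanes (lane bits 3-4) and 2 warps (warp bit 0); one CTA.
      FreeVarMasks m{/*reg=*/0b1, /*lane=*/0b11000, /*warp=*/0b1, /*block=*/0};
      bool owner = mayStore(m, /*regIdx=*/0, /*laneId=*/0, /*warpId=*/0, 0, 1);
      bool copy = mayStore(m, /*regIdx=*/1, /*laneId=*/8, /*warpId=*/1, 0, 1);
      return (owner && !copy) ? 0 : 1; // owner writes, the copy is masked
    }

Because the masks are derived from the layout itself, the same check covers blocked, MMA, and WGMMA layouts, which is what lets the test above store directly into #dst.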
@@ -264,7 +235,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,

       PTXBuilder ptxBuilder;

-      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
+      Value pred = mask ? maskElems[vecStart] : true_val();

       const std::string readConstraint =
           (width == 64) ? "l" : ((width == 32) ? "r" : "c");
@@ -437,7 +408,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
                << mask << "\n";
     }

-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
@@ -485,6 +456,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
       PTXBuilder ptxBuilder;
       auto *asmArgList = ptxBuilder.newListOperand(asmArgs);

+      Value mask = getRedundantDataMask(moduleOp, valueTy, rewriter, loc,
+                                        vecStart, targetInfo);
       Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;

       auto *asmAddr =
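In StoreOpConversion the mask is now built inside the vectorized loop because the register part of the check depends on vecStart: for a chunk whose register index has a free bit set, getRedundantDataMask folds to a constant false and the predicated store of that duplicate chunk is dropped. A rough standalone sketch of that per-chunk decision, with illustrative names and mask values only:

    // Illustrative per-chunk predicate (not the real conversion code): the
    // register part of the redundancy check depends on the chunk index
    // vecStart, so it has to be evaluated inside the vectorized store loop.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int elemsPerThread = 8, vec = 2;
      const uint32_t regFreeMask = 0b100; // hypothetical: register bit 2 is free,
                                          // so chunks 4..7 duplicate chunks 0..3
      const bool laneWarpBlockOk = true;  // runtime part of the redundancy mask
      const bool userMask[8] = {true, true, true, true, true, true, true, true};

      for (int vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
        bool redundant = (regFreeMask & vecStart) != 0;
        bool pred = !redundant && laneWarpBlockOk && userMask[vecStart];
        std::printf("chunk at %d: %s\n", vecStart, pred ? "stored" : "masked");
      }
      return 0;
    }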
@@ -577,7 +550,6 @@ struct AtomicCASOpConversion
                << " origin vec = " << vecOrig
                << " elemsPerThread = " << elemsPerThread << "\n";

-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);

@@ -607,6 +579,8 @@ struct AtomicCASOpConversion
       os << op.getSem();
       auto scope = stringifyMemSyncScope(op.getScope()).str();
       atom.global().o(semStr).o(scope).o("cas").o(sTy);
+      Value mask =
+          getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
       atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);

       if (tensorTy) {
@@ -736,12 +710,12 @@ struct AtomicRMWOpConversion
                << " packed = " << packed << " origin vec = " << vecOrig
                << " numElems = " << numElems;

-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
-
     auto packedTy = vec_ty(valueElemTy, packed);
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
       Value rmwPtr = ptrElements[i];
+      Value mask =
+          getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
       std::string sTy;
       PTXBuilder ptxBuilderAtomicRMW;
@@ -976,6 +950,7 @@ struct AsyncCopyGlobalToLocalOpConversion
                  << vecBytes << " bytes";
     }

+    auto moduleOp = op->getParentOfType<ModuleOp>();
     for (int i = 0; i < shmemAddrs.size(); i++) {
       // It's possible that vecTy is larger than 128 bits, in which case we have
       // to use multiple cp.async instructions.
@@ -1003,24 +978,26 @@ struct AsyncCopyGlobalToLocalOpConversion
         // if there's any mask. cp.async will automatically fill the
         // remaining slots with 0 if cp-size > src-size.
         // XXX(Keren): Always assume other = 0 for now.
+        // When 'other != 0' is supported, we will need to fold the
+        // op.getMask() and redundantDataMask() into the same predicate, the
+        // way it is done for LoadOp.
         auto selectOp =
             select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0));
         srcSize = ptxBuilder.newOperand(selectOp, "r");
       }

-      // When 'other != 0' is supported, we will need to fold the op.getMask()
-      // and redundantDataMask() into the same predicate, the way it is done
-      // for LoadOp.
-      Value maskVal = redundantDataMask(srcTy, rewriter, loc, targetInfo);
-
-      // TODO: Masking does not work for CTA multicast with cp.async. This is
-      // a quick and dirty workaround to avoid the issue.
       bool skipMaskForMultiCTA = triton::gpu::getNumCTAs(srcLayout) > 1;
-      if (!skipMaskForMultiCTA) {
-        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-            .predicate(maskVal);
-      } else {
+      if (skipMaskForMultiCTA) {
+        // TODO: Masking does not work for CTA multicast with cp.async.
+        // XXX(@peterbell10): In the multi-CTA mode, the redundant data might
+        // be on different CTAs which don't share the same smem address space,
+        // so we might need to load the same data multiple times.
         copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
+      } else {
+        Value mask = getRedundantDataMask(moduleOp, srcTy, rewriter, loc,
+                                          elemIdx, targetInfo);
+        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+            .predicate(mask);
       }
       ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
     }
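For cp.async the reasoning runs the other way: it is a load into shared memory, and with more than one CTA the "redundant" copies belong to different CTAs whose shared memories are separate, so every CTA still issues the copy and the redundancy predicate is skipped; with a single CTA the mask is applied per element index just like the store path. A minimal sketch of that branch, using plain C++ in place of the PTX builder:

    // Illustrative predication choice for cp.async (not the real conversion):
    // only apply the redundancy mask when all copies target the same CTA's
    // shared memory; in multi-CTA mode each CTA must load its own copy.
    #include <cstdio>

    void emitCopyAsync(bool predicated) {
      std::printf(predicated ? "cp.async (predicated)\n" : "cp.async\n");
    }

    int main() {
      const int numCTAs = 2;            // srcLayout spans more than one CTA
      const bool redundancyMask = true; // would come from getRedundantDataMask

      bool skipMaskForMultiCTA = numCTAs > 1;
      if (skipMaskForMultiCTA)
        emitCopyAsync(/*predicated=*/false); // duplicates live in different smem
      else
        emitCopyAsync(redundancyMask);
      return 0;
    }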
