 #include "Utility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
 using namespace mlir;
@@ -24,87 +25,57 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 namespace {
 
 // Return the mask for the unique data accessed by given tensor type.
-// Used to mask out the redundant data accessed by threads.
-Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
-                        Location loc, const NVIDIA::TargetInfo &targetInfo) {
+// NOTE: Redundant memory load is allowed in triton, but redundant memory store
+// is not allowed.
+// mask = true: thread can write
+// mask = false: thread should not write
+Value getRedundantDataMask(ModuleOp moduleOp, Type valueTy,
+                           ConversionPatternRewriter &rewriter, Location loc,
+                           int regIdx, const NVIDIA::TargetInfo &targetInfo) {
+  auto ctx = moduleOp.getContext();
   auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
-  Value mask = int_val(1, 1);
+  auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp);
   auto tid = tid_val();
-  auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
+  auto mask = true_val();
+  auto kReg = str_attr("register");
+  auto kLane = str_attr("lane");
+  auto kWarp = str_attr("warp");
+  auto kBlock = str_attr("block");
   if (tensorTy) {
-    auto layout = tensorTy.getEncoding();
     auto shape = tensorTy.getShape();
-    unsigned rank = shape.size();
-    auto sizePerThread = triton::gpu::getSizePerThread(layout);
-    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
-    auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
-    auto threadOrder = triton::gpu::getThreadOrder(layout);
-    SmallVector<unsigned> warpOrder(rank);
-    if (auto enc = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      warpOrder =
-          triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1);
+    auto layout = tensorTy.getEncoding();
+    auto ll = triton::gpu::toLinearLayout(shape, layout);
+    assert(ll.has_value() && "Failed to convert layout to linear layout");
+    auto freeVariableMasks = ll->getFreeVariableMasks();
+    auto regMasks = freeVariableMasks[kReg];
+    if (regMasks & regIdx) {
+      // Step 1: check register redundancy
+      mask = false_val();
     } else {
-      warpOrder = triton::gpu::getWarpOrder(layout);
-    }
-    auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
-    Value warpSize = i32_val(32);
-    Value laneId = urem(tid, warpSize);
-    Value warpId = udiv(tid, warpSize);
-    // TODO: [DOT LL]
-    // The delinearize function is not entirely correct for certain layouts,
-    // such as wgmma. The correct approach is to convert a legacy layout to its
-    // corresponding linear layout and use the linear layout's
-    // getFreeVariableMasks to identify redundant elements.
-    SmallVector<Value> multiDimWarpId =
-        delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
-    SmallVector<Value> multiDimThreadId =
-        delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder);
-    for (unsigned dim = 0; dim < rank; ++dim) {
-      // if there is no data replication across threads on this dimension
-      if (shape[dim] >= shapePerCTATile[dim])
-        continue;
-      // Otherwise, we need to mask threads that will replicate data on this
-      // dimension. Calculate the thread index on this dimension for the CTA
-      Value threadDim =
-          add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
-              multiDimThreadId[dim]);
-      mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
-                                 i32_val(shape[dim])));
-    }
-    // Do not write duplicated data when multicast is enabled
-    if (triton::gpu::getNumCTAs(layout) > 1) {
-      auto _0 = i32_val(0);
-      auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout);
-      auto CTASplitNum = triton::gpu::getCTASplitNum(layout);
-      auto CTAOrder = triton::gpu::getCTAOrder(layout);
-
-      auto multiDimClusterCTAId =
-          delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
-
-      for (unsigned dim = 0; dim < rank; ++dim) {
-        // Skip when multicast is not enabled in this dimension
-        if (CTAsPerCGA[dim] == CTASplitNum[dim])
-          continue;
-        // This wrapping rule must be consistent with emitCTAOffsetForLayout
-        unsigned splitNum = std::min<unsigned>(shape[dim], CTASplitNum[dim]);
-        Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum));
-        // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]:
-        // CTA0 and CTA2 holds data of block0,
-        // CTA1 and CTA3 holds data of block1.
-        // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should
-        // be masked. We add the following mask:
-        // multiDimClusterCTAId[dim] / splitNum == 0
-        // Actually in all existing cases of multicast, splitNum is always 1.
-        // The mask is equivalent to:
-        // multiDimClusterCTAId[dim] == 0
-        mask = and_(mask, icmp_eq(repId, _0));
+      Value warpSize =
+          i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(moduleOp));
+      Value laneId = urem(tid, warpSize);
+      Value warpId = udiv(tid, warpSize);
+      // Step 2: check lane and warp redundancy
+      auto laneMasks = freeVariableMasks[kLane];
+      auto warpMasks = freeVariableMasks[kWarp];
+      mask = and_(mask, icmp_eq(and_(i32_val(laneMasks), laneId), i32_val(0)));
+      mask = and_(mask, icmp_eq(and_(i32_val(warpMasks), warpId), i32_val(0)));
+      if (numCTAs > 1) {
+        // Step 3: check block redundancy
+        auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
+        auto ctaMasks = freeVariableMasks[kBlock];
+        mask = and_(mask, icmp_eq(and_(i32_val(ctaMasks), ctaId), i32_val(0)));
       }
     }
   } else {
-    // If the tensor is not ranked, then it is a scalar and only thread 0 of
-    // CTA0 can write
-    mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0)));
     mask = and_(mask, icmp_eq(tid, i32_val(0)));
+    if (numCTAs > 1) {
+      auto ctaId = targetInfo.getClusterCTAId(rewriter, loc);
+      // If the tensor is not ranked, then it is a scalar and only thread 0 of
+      // CTA0 within the cluster can write
+      mask = and_(mask, icmp_eq(ctaId, i32_val(0)));
+    }
   }
   return mask;
 }
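The core of the new helper is a bit test against the linear layout's free-variable masks: a set bit marks a register/lane/warp/block index bit along which the layout maps to the same tensor element, so only the holder with all of those bits cleared is allowed to write. Below is a minimal standalone sketch of that test, with a made-up lane mask; it is an illustration, not code from this commit.

// Hypothetical illustration of the (id & freeVarMask) == 0 test emitted by
// getRedundantDataMask(); the mask value below is invented for the example.
#include <cstdint>
#include <cstdio>

// True iff this id is the unique writer among all ids holding the same data.
static bool isUniqueWriter(uint32_t id, uint32_t freeVarMask) {
  return (id & freeVarMask) == 0;
}

int main() {
  const uint32_t laneMask = 0b110; // assume lanes differing in bits 1-2 hold duplicates
  for (uint32_t lane = 0; lane < 8; ++lane)
    std::printf("lane %u: %s\n", lane,
                isUniqueWriter(lane, laneMask) ? "write" : "skip");
  // Only lanes 0 and 1 (bits 1-2 clear) print "write".
}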
@@ -264,7 +235,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
 
       PTXBuilder ptxBuilder;
 
-      Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
+      Value pred = mask ? maskElems[vecStart] : true_val();
 
       const std::string readConstraint =
           (width == 64) ? "l" : ((width == 32) ? "r" : "c");
@@ -437,7 +408,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
                << mask << "\n";
     }
 
-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
@@ -485,6 +456,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
       PTXBuilder ptxBuilder;
       auto *asmArgList = ptxBuilder.newListOperand(asmArgs);
 
+      Value mask = getRedundantDataMask(moduleOp, valueTy, rewriter, loc,
+                                        vecStart, targetInfo);
       Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
 
       auto *asmAddr =
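Note that the register free-variable check is decided per vector chunk: `regIdx` is the first register index of the chunk, and if any of its set bits fall in the register free-variable mask, the whole chunk duplicates another chunk and its store is disabled outright via `false_val()`. A small standalone sketch of that walk, assuming a hypothetical register mask and vector width (values invented for illustration):

// Hypothetical walk over vector chunks mirroring `if (regMasks & regIdx)`;
// regMask, elemsPerThread and vec are invented values for the example.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t regMask = 0b100;           // assumed "register" free-variable bits
  const unsigned elemsPerThread = 16, vec = 4;
  for (unsigned vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
    bool redundant = (regMask & vecStart) != 0;
    std::printf("chunk at reg %2u: %s\n", vecStart,
                redundant ? "masked off (false_val)" : "stored");
  }
  // Chunks at 4 and 12 duplicate chunks at 0 and 8, so only the latter store.
}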
@@ -577,7 +550,6 @@ struct AtomicCASOpConversion
                << " origin vec = " << vecOrig
                << " elemsPerThread = " << elemsPerThread << "\n";
 
-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     auto vecTy = vec_ty(valueElemTy, vec);
     SmallVector<Value> resultVals(elemsPerThread);
 
@@ -607,6 +579,8 @@ struct AtomicCASOpConversion
       os << op.getSem();
       auto scope = stringifyMemSyncScope(op.getScope()).str();
       atom.global().o(semStr).o(scope).o("cas").o(sTy);
+      Value mask =
+          getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
       atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
 
       if (tensorTy) {
@@ -736,12 +710,12 @@ struct AtomicRMWOpConversion
              << " packed = " << packed << " origin vec = " << vecOrig
              << " numElems = " << numElems;
 
-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
-
     auto packedTy = vec_ty(valueElemTy, packed);
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
       Value rmwPtr = ptrElements[i];
+      Value mask =
+          getRedundantDataMask(moduleOp, valueTy, rewriter, loc, i, targetInfo);
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
       std::string sTy;
       PTXBuilder ptxBuilderAtomicRMW;
@@ -976,6 +950,7 @@ struct AsyncCopyGlobalToLocalOpConversion
                          << vecBytes << " bytes";
     }
 
+    auto moduleOp = op->getParentOfType<ModuleOp>();
     for (int i = 0; i < shmemAddrs.size(); i++) {
       // It's possible that vecTy is larger than 128 bits, in which case we have
       // to use multiple cp.async instructions.
@@ -1003,24 +978,26 @@ struct AsyncCopyGlobalToLocalOpConversion
         // if there's any mask. cp.async will automatically fill the
         // remaining slots with 0 if cp-size > src-size.
         // XXX(Keren): Always assume other = 0 for now.
+        // When 'other != 0' is supported, we will need to fold the
+        // op.getMask() and redundantDataMask() into the same predicate, the
+        // way it is done for LoadOp.
         auto selectOp =
             select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0));
         srcSize = ptxBuilder.newOperand(selectOp, "r");
       }
 
-      // When 'other != 0' is supported, we will need to fold the op.getMask()
-      // and redundantDataMask() into the same predicate, the way it is done
-      // for LoadOp.
-      Value maskVal = redundantDataMask(srcTy, rewriter, loc, targetInfo);
-
-      // TODO: Masking does not work for CTA multicast with cp.async. This is
-      // a quick and dirty workaround to avoid the issue.
       bool skipMaskForMultiCTA = triton::gpu::getNumCTAs(srcLayout) > 1;
-      if (!skipMaskForMultiCTA) {
-        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
-            .predicate(maskVal);
-      } else {
+      if (skipMaskForMultiCTA) {
+        // TODO: Masking does not work for CTA multicast with cp.async.
+        // XXX(@peterbell10): In the multi-CTA mode, the redundant data might
+        // be on different CTAs which don't share the same smem address space,
+        // so we might need to load the same data multiple times.
         copyAsyncOp(dstOperand, srcOperand, copySize, srcSize);
+      } else {
+        Value mask = getRedundantDataMask(moduleOp, srcTy, rewriter, loc,
+                                          elemIdx, targetInfo);
+        copyAsyncOp(dstOperand, srcOperand, copySize, srcSize)
+            .predicate(mask);
       }
       ptxBuilder.launch(rewriter, loc, void_ty(getContext()));
     }