1+ #include " triton/Conversion/TritonGPUToLLVM/Utility.h"
2+ #include " triton/../../lib/Conversion/TritonGPUToLLVM/Utility.cpp"
3+
4+ namespace mlir {
5+ namespace LLVM {
6+
7+ Value createIndexConstant (OpBuilder &builder, Location loc,
8+ TypeConverter *converter, int64_t value) {
9+ Type ty = converter->convertType (builder.getIndexType ());
10+ return builder.create <LLVM::ConstantOp>(loc, ty,
11+ builder.getIntegerAttr (ty, value));
12+ }
13+
14+ SmallVector<Value> getMultiDimOffset (Attribute layout, Location loc,
15+ ConversionPatternRewriter &rewriter,
16+ const TargetInfoBase &targetInfo,
17+ unsigned elemId, RankedTensorType type,
18+ ArrayRef<unsigned > multiDimCTAInRepId,
19+ ArrayRef<unsigned > shapePerCTATile,
20+ bool isTrans, bool stNotRd) {
21+ auto shape = type.getShape ();
22+ unsigned rank = shape.size ();
23+ if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
24+ auto multiDimOffsetFirstElem = emitBaseIndexForLayout (
25+ loc, rewriter, targetInfo, blockedLayout, type, false );
26+ SmallVector<Value> multiDimOffset (rank);
27+ SmallVector<unsigned > multiDimElemId = getMultiDimIndex<unsigned >(
28+ elemId, getSizePerThread (layout), getOrder (layout));
29+ for (unsigned d = 0 ; d < rank; ++d) {
30+ multiDimOffset[d] =
31+ add (multiDimOffsetFirstElem[d],
32+ i32_val (multiDimCTAInRepId[d] * shapePerCTATile[d] +
33+ multiDimElemId[d]));
34+ }
35+ return multiDimOffset;
36+ }
37+ if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
38+ unsigned dim = sliceLayout.getDim ();
39+ auto parentEncoding = sliceLayout.getParent ();
40+ auto parentSizePerThread = getSizePerThread (parentEncoding);
41+ auto parentShape = sliceLayout.paddedShape (shape);
42+ auto parentTy = RankedTensorType::get (parentShape, type.getElementType (),
43+ parentEncoding);
44+ auto offsets = emitOffsetForLayout (layout, type);
45+ auto parentOffset = emitOffsetForLayout (parentEncoding, parentTy);
46+ SmallVector<int > idxs;
47+ for (SmallVector<unsigned > off : offsets) {
48+ off.insert (off.begin () + dim, 0 );
49+ auto it = std::find (parentOffset.begin (), parentOffset.end (), off);
50+ idxs.push_back (std::distance (parentOffset.begin (), it));
51+ }
52+ auto multiDimOffsetParent = getMultiDimOffset (
53+ parentEncoding, loc, rewriter, targetInfo, idxs[elemId], parentTy,
54+ sliceLayout.paddedShape (multiDimCTAInRepId),
55+ sliceLayout.paddedShape (shapePerCTATile));
56+ SmallVector<Value> multiDimOffset (rank);
57+ for (unsigned d = 0 ; d < rank + 1 ; ++d) {
58+ if (d == dim)
59+ continue ;
60+ unsigned slicedD = d < dim ? d : (d - 1 );
61+ multiDimOffset[slicedD] = multiDimOffsetParent[d];
62+ }
63+ return multiDimOffset;
64+ }
65+ if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
66+ assert (rank == 2 ||
67+ (rank == 3 && mmaLayout.isAmpere ()) && " Unexpected rank" );
68+ auto shapePerCTA = getShapePerCTA (mmaLayout, shape);
69+ auto instrShape = mmaLayout.getInstrShape ();
70+ SmallVector<Value> mmaColIdx (2 );
71+ SmallVector<Value> mmaRowIdx (2 );
72+ Value threadId = getThreadId (rewriter, loc);
73+ Value warpSize = i32_val (32 );
74+ Value laneId = urem (threadId, warpSize);
75+ Value warpId = udiv (threadId, warpSize);
76+ // TODO: fix the bug in MMAEncodingAttr document
77+ SmallVector<Value> multiDimWarpId (2 );
78+ auto warpsPerCTA = mmaLayout.getWarpsPerCTA ();
79+ auto warpOrder = triton::gpu::getWarpOrder (mmaLayout);
80+ multiDimWarpId = delinearize (rewriter, loc, warpId, warpsPerCTA, warpOrder);
81+ Value _1 = i32_val (1 );
82+ Value _2 = i32_val (2 );
83+ Value _4 = i32_val (4 );
84+ Value _8 = i32_val (8 );
85+ Value _16 = i32_val (16 );
86+ if (mmaLayout.isAmpere () || mmaLayout.isHopper ()) {
87+ multiDimWarpId[rank - 1 ] = urem (
88+ multiDimWarpId[rank - 1 ],
89+ i32_val (ceil<unsigned >(shapePerCTA[rank - 1 ], instrShape[rank - 1 ])));
90+ multiDimWarpId[rank - 2 ] = urem (
91+ multiDimWarpId[rank - 2 ],
92+ i32_val (ceil<unsigned >(shapePerCTA[rank - 2 ], instrShape[rank - 2 ])));
93+
94+ Value mmaGrpId = udiv (laneId, _4);
95+ Value mmaGrpIdP8 = add (mmaGrpId, _8);
96+ Value mmaThreadIdInGrp = urem (laneId, _4);
97+ Value mmaThreadIdInGrpM2 = mul (mmaThreadIdInGrp, _2);
98+ Value mmaThreadIdInGrpM2P1 = add (mmaThreadIdInGrpM2, _1);
99+ Value rowWarpOffset =
100+ mul (multiDimWarpId[rank - 2 ], i32_val (instrShape[rank - 2 ]));
101+ mmaRowIdx[0 ] = add (mmaGrpId, rowWarpOffset);
102+ mmaRowIdx[1 ] = add (mmaGrpIdP8, rowWarpOffset);
103+ Value colWarpOffset =
104+ mul (multiDimWarpId[rank - 1 ], i32_val (instrShape[rank - 1 ]));
105+ mmaColIdx[0 ] = add (mmaThreadIdInGrpM2, colWarpOffset);
106+ mmaColIdx[1 ] = add (mmaThreadIdInGrpM2P1, colWarpOffset);
107+ } else if (mmaLayout.isVolta ()) {
108+ // Volta doesn't follow the pattern here.
109+ } else {
110+ llvm_unreachable (" Unexpected MMALayout version" );
111+ }
112+
113+ SmallVector<Value> multiDimOffset (rank);
114+ if (mmaLayout.isHopper ()) {
115+ unsigned elemIdRem4 = elemId % 4 ;
116+ unsigned nGrpId = elemId / 4 ;
117+ multiDimOffset[0 ] = elemIdRem4 < 2 ? mmaRowIdx[0 ] : mmaRowIdx[1 ];
118+ multiDimOffset[1 ] = elemIdRem4 % 2 == 0 ? mmaColIdx[0 ] : mmaColIdx[1 ];
119+ multiDimOffset[1 ] = add (multiDimOffset[1 ], i32_val (8 * nGrpId));
120+ multiDimOffset[0 ] = add (multiDimOffset[0 ], i32_val (multiDimCTAInRepId[0 ] *
121+ shapePerCTATile[0 ]));
122+ multiDimOffset[1 ] = add (multiDimOffset[1 ], i32_val (multiDimCTAInRepId[1 ] *
123+ shapePerCTATile[1 ]));
124+ } else if (mmaLayout.isAmpere ()) {
125+ if (rank == 3 )
126+ multiDimOffset[0 ] =
127+ add (multiDimWarpId[0 ],
128+ i32_val (multiDimCTAInRepId[0 ] * shapePerCTATile[0 ]));
129+ multiDimOffset[rank - 2 ] = elemId < 2 ? mmaRowIdx[0 ] : mmaRowIdx[1 ];
130+ multiDimOffset[rank - 1 ] = elemId % 2 == 0 ? mmaColIdx[0 ] : mmaColIdx[1 ];
131+ multiDimOffset[rank - 2 ] =
132+ add (multiDimOffset[rank - 2 ], i32_val (multiDimCTAInRepId[rank - 2 ] *
133+ shapePerCTATile[rank - 2 ]));
134+ multiDimOffset[rank - 1 ] =
135+ add (multiDimOffset[rank - 1 ], i32_val (multiDimCTAInRepId[rank - 1 ] *
136+ shapePerCTATile[rank - 1 ]));
137+ } else if (mmaLayout.isVolta ()) {
138+ auto [isARow, isBRow, isAVec4, isBVec4, _] =
139+ mmaLayout.decodeVoltaLayoutStates ();
140+ auto coords = SharedToDotOperandMMAv1::getMNCoords (
141+ threadId, loc, rewriter, mmaLayout.getWarpsPerCTA (), mmaLayout, shape,
142+ isARow, isBRow, isAVec4, isBVec4);
143+ return coords[elemId];
144+ } else {
145+ llvm_unreachable (" Unexpected MMALayout version" );
146+ }
147+ return multiDimOffset;
148+ }
149+ if (auto mmaLayout = mlir::dyn_cast<IluvatarMmaEncodingAttr>(layout)) {
150+ assert (rank == 2 && " Unexpected rank" );
151+ SmallVector<Value> multiDimOffset (rank);
152+ Value threadId = getThreadId (rewriter, loc);
153+ if (mmaLayout.isVolta ()) {
154+ int bitwidth = type.getElementType ().getIntOrFloatBitWidth ();
155+ int elemVecSize = stNotRd ? (32 / bitwidth) : 1 ;
156+ static auto func = SharedToDotOperandMMAv1::load_getMNCoords_func (
157+ " iluvatar" , " getMNCoords" );
158+ auto coords = func (threadId, loc, rewriter, mmaLayout.getWarpsPerCTA (),
159+ mmaLayout, shape, bitwidth, elemVecSize, isTrans);
160+ return coords[elemId];
161+ } else {
162+ llvm_unreachable (" Unexpected MMALayout version" );
163+ }
164+ }
165+ if (isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(layout)) {
166+ auto multiDimBase =
167+ emitBaseIndexForLayout (loc, rewriter, targetInfo, layout, type, false );
168+ SmallVector<SmallVector<unsigned >> offsets;
169+ assert (rank == 2 );
170+ SmallVector<Value> multiDimOffset (rank);
171+ if (auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout)) {
172+ emitMfmaOffsetForCTA (mfmaLayout, offsets, 0 , multiDimCTAInRepId[0 ],
173+ multiDimCTAInRepId[1 ]);
174+ } else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
175+ emitWmmaOffsetForCTA (wmmaLayout, offsets, 0 , multiDimCTAInRepId[0 ],
176+ multiDimCTAInRepId[1 ]);
177+ }
178+ multiDimOffset[0 ] = add (multiDimBase[0 ], i32_val (offsets[elemId][0 ]));
179+ multiDimOffset[1 ] = add (multiDimBase[1 ], i32_val (offsets[elemId][1 ]));
180+ return multiDimOffset;
181+ }
182+ llvm_unreachable (" unexpected layout in getMultiDimOffset" );
183+ }
184+
185+ } // namespace LLVM
186+ } // namespace mlir
0 commit comments