11#include " ../TritonGPUToLLVMBase.h"
22#include " ../Utility.h"
33#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
4+ #include " mlir/Support/LLVM.h"
5+ #include " triton/Dialect/TritonGPU/IR/Dialect.h"
46#include " llvm/Support/ErrorHandling.h"
57
68using ValueTable = std::map<std::array<int , 3 >, Value>;
@@ -16,16 +18,18 @@ template <unsigned opIdx> class DpasMatmulLoader {
1618 DpasMatmulLoader (DpasEncodingAttr dpasLayout, MemDescType descTy,
1719 unsigned warpsPerTile, ArrayRef<Value> smemStrides,
1820 const SmallVector<unsigned > &warpShape,
21+ SmallVector<Value> multiDimWarpId,
1922 ConversionPatternRewriter &rewriter,
2023 const LLVMTypeConverter *typeConverter, Location loc)
2124 : dpasLayout(dpasLayout), descTy(descTy), smemStrides(smemStrides),
22- rewriter (rewriter), loc(loc) {
25+ multiDimWarpId (multiDimWarpId), rewriter(rewriter), loc(loc) {
2326 static_assert (opIdx == 0 || opIdx == 1 );
2427
2528 size_t rank = warpShape.size ();
26- unsigned kDim = (opIdx == 0 ) ? rank - 1 : rank - 2 ;
27- unsigned nonKDim = (opIdx == 0 ) ? rank - 2 : rank - 1 ;
28- repBatchDimStride = rank == 3 ? smemStrides[0 ] : i32_val (1 );
29+ unsigned kDim = opIdx ? rank - 2 : rank - 1 ;
30+ unsigned nonKDim = opIdx ? rank - 1 : rank - 2 ;
31+ // Assume that smem is create with layout offset {2, 1, 0}
32+ repBatchDimStride = smemStrides[0 ];
2933 repKDimStride = mul (i32_val (warpShape[kDim ]), smemStrides[kDim ]);
3034 repNonKDimStride =
3135 mul (i32_val (warpShape[nonKDim] * warpsPerTile), smemStrides[nonKDim]);
@@ -60,6 +64,7 @@ template <unsigned opIdx> class DpasMatmulLoader {
6064 MemDescType descTy;
6165
6266 SmallVector<Value> smemStrides;
67+ SmallVector<Value> multiDimWarpId;
6368 Value repBatchDimStride;
6469 Value repNonKDimStride;
6570 Value repKDimStride;
@@ -133,6 +138,7 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
133138 SmallVector<unsigned > instShape = opIdx == 0 ? dpasLayout.getDPASInstShapeA ()
134139 : dpasLayout.getDPASInstShapeB ();
135140 ArrayRef<int64_t > shareMemoryShape = descTy.getShape ();
141+ SmallVector<int64_t > shapePerCTA = getShapePerCTA (descTy);
136142
137143 SmallVector<Value> offs (numPtrs);
138144 const unsigned repClusterSize =
@@ -160,8 +166,8 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
160166
161167 // round the offset into the tensor's shape limitation. (Rounded
162168 // broadcast)
163- iBase = urem (iBase, i32_val (shareMemoryShape[0 ]));
164- jBase = urem (jBase, i32_val (shareMemoryShape[1 ]));
169+ iBase = urem (iBase, i32_val (shareMemoryShape[rank - 2 ]));
170+ jBase = urem (jBase, i32_val (shareMemoryShape[rank - 1 ]));
165171
166172 // inner index offset
167173 Value jOff = zeroVal;
@@ -170,10 +176,17 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
170176 jOff = add (jOff, udiv (cSwizzleOffset, vecVal));
171177 jOff = mul (xor_ (jOff, phase), vecVal);
172178
173- Value i = add (mul (iBase, smemStrides[0 ]), iOff);
174- Value j = add (mul (jBase, smemStrides[1 ]), jOff);
179+ Value i = add (mul (iBase, smemStrides[rank - 2 ]), iOff);
180+ Value j = add (mul (jBase, smemStrides[rank - 1 ]), jOff);
175181
176- offs[index++] = add (i, j);
182+ Value baseOff;
183+ if (shapePerCTA.size () == 3 && shapePerCTA[0 ] > 1 ) {
184+ Value batchOffset =
185+ mul (multiDimWarpId[0 ], i32_val (shapePerCTA[1 ] * shapePerCTA[2 ]));
186+ offs[index++] = add (batchOffset, add (i, j));
187+ } else {
188+ offs[index++] = add (i, j);
189+ }
177190 }
178191 }
179192 }
@@ -190,13 +203,14 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(
190203 llvm::any_of (structTy.getBody (), [&](Type ty) { return ty == elemTy; }) &&
191204 " The struct should have the same element types." );
192205
193- Value offsetBatch = mul (i32_val (repBatch), repBatchDimStride);
194206 Value offsetOuter = mul (i32_val (repOuter), repNonKDimStride);
195207 Value offsetInner = mul (i32_val (repInner), repKDimStride);
196208 Value offset = add (offsetOuter, offsetInner);
197- // FIXME: repBatchSize and
209+ SmallVector<unsigned > warpsPerCTA = dpasLayout.getWarpsPerCTA ();
210+ // 3DTODO: check if repBatch * warpsPerCTA[0] is correct for the offset.
198211 if (repBatch > 0 ) {
199- Value offsetBatch = mul (i32_val (repBatch), repBatchDimStride);
212+ Value offsetBatch =
213+ mul (i32_val (repBatch * warpsPerCTA[0 ]), repBatchDimStride);
200214 offset = add (offset, offsetBatch);
201215 }
202216
@@ -260,7 +274,8 @@ template <unsigned opIdx>
260274std::function<void (int , int , int )>
261275getLoadMatrixFn (MemDescType descTy, const SharedMemoryObject &smemObj,
262276 DpasEncodingAttr dpasLayout, unsigned warpsPerTile,
263- SmallVector<unsigned > instrShape, Value warpId,
277+ SmallVector<unsigned > shapePerWarp,
278+ SmallVector<Value> multiDimWarpId, Value warpId,
264279 Value outerWarpDim, Value laneId, ValueTable &vals,
265280 const LLVMTypeConverter *typeConverter,
266281 ConversionPatternRewriter &rewriter, Location loc) {
@@ -271,13 +286,14 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
271286
272287 auto sharedLayout = cast<SharedEncodingAttr>(descTy.getEncoding ());
273288 ArrayRef<unsigned > order = sharedLayout.getOrder ();
289+ unsigned rank = order.size ();
274290
275291 // (a, b) is the coordinate.
276- auto load = [=, &rewriter, &smemObj, &instrShape , &vals]( int batch, int outer ,
277- int inner) {
278- DpasMatmulLoader<opIdx> loader (dpasLayout, descTy, warpsPerTile,
279- smemObj.strides , instrShape, rewriter ,
280- typeConverter, loc);
292+ auto load = [=, &rewriter, &smemObj, &shapePerWarp , &multiDimWarpId ,
293+ &vals]( int batch, int outer, int inner) {
294+ DpasMatmulLoader<opIdx> loader (
295+ dpasLayout, descTy, warpsPerTile, smemObj.strides , shapePerWarp ,
296+ multiDimWarpId, rewriter, typeConverter, loc);
281297
282298 // Offset of a slice within the original tensor in shared memory.
283299 Value cSwizzleOffset = smemObj.getCSwizzleOffset (order[0 ]);
@@ -295,7 +311,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
295311 gep (ptr_ty (rewriter.getContext (), 3 ), smemTy, smemBase, offs[i]);
296312
297313 // Load from shared memory.
298- unsigned totalElem = product<unsigned >(instrShape );
314+ unsigned totalElem = product<unsigned >(shapePerWarp );
299315 unsigned threadsPerWarp = product<unsigned >(getThreadsPerWarp (dpasLayout));
300316 auto matTy = LLVM::LLVMStructType::getLiteral (
301317 eltTy.getContext (),
@@ -348,8 +364,9 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
348364 // Get the function to use to load the operand.
349365 ValueTable vals;
350366 std::function<void (int , int , int )> loadFn = getLoadMatrixFn<opIdx>(
351- descTy, smemObj, dpasLayout, warpsPerTile, std::move (shape), warpId,
352- outerWarpDim, laneId, vals, typeConverter, rewriter, loc);
367+ descTy, smemObj, dpasLayout, warpsPerTile, std::move (shape),
368+ std::move (multiDimWarpId), warpId, outerWarpDim, laneId, vals,
369+ typeConverter, rewriter, loc);
353370
354371 // Load the operand.
355372 int64_t numRepBatch = numReps[0 ];
0 commit comments