@@ -3,7 +3,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "llvm/Support/ErrorHandling.h"
 
-using ValueTable = std::map<std::pair<int, int>, Value>;
+using ValueTable = std::map<std::array<int, 3>, Value>;
 using mlir::triton::gpu::getShapePerCTA;
 using mlir::triton::gpu::SharedEncodingAttr;
 using mlir::triton::gpu::intel::DpasEncodingAttr;
@@ -44,9 +44,9 @@ template <unsigned opIdx> class DpasMatmulLoader {
   SmallVector<Value> computeLdsMatOffs(Value warpOff, Value lane,
                                        Value cSwizzleOffset);
   // Load the matrix value.
-  Value loadMatrix(int repOuter, int repInner, const ArrayRef<Value> ptrs,
-                   LLVM::LLVMStructType structTy, Type smemTy,
-                   Value cSwizzleOffset) const;
+  Value loadMatrix(int repBatch, int repOuter, int repInner,
+                   const ArrayRef<Value> ptrs, LLVM::LLVMStructType structTy,
+                   Type smemTy, Value cSwizzleOffset) const;
 
 private:
   unsigned getThreadsPerWarp() const {
@@ -57,6 +57,7 @@ template <unsigned opIdx> class DpasMatmulLoader {
   MemDescType descTy;
 
   SmallVector<Value> smemStrides;
+  Value repBatchDimStride;
   Value repNonKDimStride;
   Value repKDimStride;
 
@@ -176,19 +177,19 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
 }
 
 template <unsigned opIdx>
-Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
-                                          const ArrayRef<Value> ptrs,
-                                          LLVM::LLVMStructType structTy,
-                                          Type smemTy,
-                                          Value cSwizzleOffset) const {
+Value DpasMatmulLoader<opIdx>::loadMatrix(
+    int repBatch, int repOuter, int repInner, const ArrayRef<Value> ptrs,
+    LLVM::LLVMStructType structTy, Type smemTy, Value cSwizzleOffset) const {
   Type elemTy = structTy.getBody()[0];
   assert(
       llvm::any_of(structTy.getBody(), [&](Type ty) { return ty == elemTy; }) &&
       "The struct should have the same element types.");
 
+  Value offsetBatch = mul(i32_val(repBatch), repBatchDimStride);
   Value offsetOuter = mul(i32_val(repOuter), repNonKDimStride);
   Value offsetInner = mul(i32_val(repInner), repKDimStride);
   Value offset = add(offsetOuter, offsetInner);
+  offset = add(offset, offsetBatch);
 
   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
   size_t elemNum = structTy.getBody().size();
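For reference, a tiny plain-integer sketch of the offset arithmetic this hunk introduces (the stride values below are made up; the real code builds MLIR `Value`s with `mul`/`add`): the batch repetition index now contributes `repBatch * repBatchDimStride` on top of the existing outer and inner terms.

```cpp
#include <cstdio>

// Illustrative only: plain-integer version of the offset computed in
// loadMatrix above. All stride and repetition values are invented.
int main() {
  const int repBatchDimStride = 4096, repNonKDimStride = 256, repKDimStride = 16;
  const int repBatch = 1, repOuter = 2, repInner = 3;
  int offset = repOuter * repNonKDimStride + repInner * repKDimStride;
  offset += repBatch * repBatchDimStride; // the new batch contribution
  std::printf("offset = %d elements\n", offset);
  return 0;
}
```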
@@ -203,18 +204,20 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
 }
 
 Value composeValuesToDotOperandLayoutStruct(
-    const ValueTable &vals, int n0, int n1,
+    const ValueTable &vals, int batch, int n0, int n1,
     const LLVMTypeConverter *typeConverter, Location loc,
     ConversionPatternRewriter &rewriter) {
   std::vector<Value> elems;
-  for (int m = 0; m < n0; ++m) {
-    for (int k = 0; k < n1; ++k) {
-      Value matVal = vals.at({m, k});
-      auto matType = cast<LLVM::LLVMStructType>(matVal.getType());
-      Type valTy = matType.getBody()[0];
-      for (int i = 0; i < matType.getBody().size(); ++i) {
-        auto val = extract_val(valTy, matVal, i);
-        elems.push_back(val);
+  for (int b = 0; b < batch; ++b) {
+    for (int m = 0; m < n0; ++m) {
+      for (int k = 0; k < n1; ++k) {
+        Value matVal = vals.at({b, m, k});
+        auto matType = cast<LLVM::LLVMStructType>(matVal.getType());
+        Type valTy = matType.getBody()[0];
+        for (int i = 0; i < matType.getBody().size(); ++i) {
+          auto val = extract_val(valTy, matVal, i);
+          elems.push_back(val);
+        }
       }
     }
   }
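A minimal standalone sketch of the flattening order produced by the batch-extended loop nest above (illustrative only; MLIR `Value`s are replaced by plain `int`s and the extents are invented). Elements come out batch-major, then outer, then inner, matching the new three-index `ValueTable` key.

```cpp
#include <array>
#include <cstdio>
#include <map>
#include <vector>

// Toy stand-in for ValueTable: the MLIR Value payload becomes an int so the
// emission order can be printed.
using ToyValueTable = std::map<std::array<int, 3>, int>;

int main() {
  const int batch = 2, n0 = 2, n1 = 3; // invented repetition counts
  ToyValueTable vals;
  int next = 0;
  for (int b = 0; b < batch; ++b)
    for (int m = 0; m < n0; ++m)
      for (int k = 0; k < n1; ++k)
        vals[{b, m, k}] = next++;

  // Flatten in the same batch-major order as the loop nest above.
  std::vector<int> elems;
  for (int b = 0; b < batch; ++b)
    for (int m = 0; m < n0; ++m)
      for (int k = 0; k < n1; ++k)
        elems.push_back(vals.at({b, m, k}));

  for (int e : elems)
    std::printf("%d ", e);
  std::printf("\n");
  return 0;
}
```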
@@ -245,7 +248,7 @@ Type getSharedMemTy(Type argType) {
 }
 
 template <unsigned opIdx>
-std::function<void(int, int)>
+std::function<void(int, int, int)>
 getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
                 DpasEncodingAttr dpasLayout, unsigned warpsPerTile,
                 SmallVector<unsigned> instrShape, Value warpId,
@@ -261,7 +264,8 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
   ArrayRef<unsigned> order = sharedLayout.getOrder();
 
   // (a, b) is the coordinate.
-  auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int a, int b) {
+  auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int batch, int outer,
+                                                           int inner) {
     DpasMatmulLoader<opIdx> loader(dpasLayout, descTy, warpsPerTile,
                                    smemObj.strides, instrShape, rewriter,
                                    typeConverter, loc);
@@ -289,7 +293,8 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
                       SmallVector<Type>(totalElem / threadsPerWarp,
                                         typeConverter->convertType(eltTy)));
 
-    vals[{a, b}] = loader.loadMatrix(a, b, ptrs, matTy, smemTy, cSwizzleOffset);
+    vals[{batch, outer, inner}] = loader.loadMatrix(
+        batch, outer, inner, ptrs, matTy, smemTy, cSwizzleOffset);
   };
 
   return load;
@@ -325,27 +330,32 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
       LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);
 
   // FIXME: Using opIdx as the dimIdx will be incorrect in 3D case.
-  unsigned ceilRes = mlir::ceil<unsigned>(shapePerCTA[opIdx], shape[opIdx]);
-  Value outerWarpDim = urem(multiDimWarpId[opIdx], i32_val(ceilRes));
-  unsigned warpsPerTile = std::min<unsigned>(warpsPerCTA[opIdx], ceilRes);
+  unsigned rank = shape.size();
+  unsigned dimOuter = opIdx ? (rank - 1) : (rank - 2);
+  unsigned ceilRes =
+      mlir::ceil<unsigned>(shapePerCTA[dimOuter], shape[dimOuter]);
+  Value outerWarpDim = urem(multiDimWarpId[dimOuter], i32_val(ceilRes));
+  unsigned warpsPerTile = std::min<unsigned>(warpsPerCTA[dimOuter], ceilRes);
 
   // Get the function to use to load the operand.
   ValueTable vals;
-  std::function<void(int, int)> loadFn = getLoadMatrixFn<opIdx>(
+  std::function<void(int, int, int)> loadFn = getLoadMatrixFn<opIdx>(
       descTy, smemObj, dpasLayout, warpsPerTile, std::move(shape), warpId,
       outerWarpDim, laneId, vals, typeConverter, rewriter, loc);
 
   // Load the operand.
-  int64_t numRepOuter = numReps[opIdx];
-  int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
+  int64_t numRepBatch = numReps[0];
+  int64_t numRepOuter = numReps[opIdx ? 2 : 1];
+  int64_t numRepK = numReps[opIdx ? 1 : 2];
 
-  for (int m = 0; m < numRepOuter; ++m)
-    for (int k = 0; k < numRepK; ++k)
-      loadFn(m, k);
+  for (int b = 0; b < numRepBatch; ++b)
+    for (int m = 0; m < numRepOuter; ++m)
+      for (int k = 0; k < numRepK; ++k)
+        loadFn(b, m, k);
 
   // Format the values into an LLVM::Struct.
-  return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK,
-                                               typeConverter, loc, rewriter);
+  return composeValuesToDotOperandLayoutStruct(
+      vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter);
 }
 
 } // namespace
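As a rough standalone illustration of the `numReps` bookkeeping in `loadOperand` above (the counts below are invented; only the index mapping mirrors the change): index 0 holds the batch repetition count, while the outer and K counts swap between indices 1 and 2 depending on whether the operand index is 0 or 1.

```cpp
#include <array>
#include <cstdio>

int main() {
  // Hypothetical repetition counts; index 0 is the batch dimension.
  std::array<long long, 3> numReps = {2, 4, 3};
  for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
    long long numRepBatch = numReps[0];
    // Outer and K counts come from indices 1/2 for opIdx == 0 and are
    // swapped for opIdx == 1, mirroring the diff above.
    long long numRepOuter = numReps[opIdx ? 2 : 1];
    long long numRepK = numReps[opIdx ? 1 : 2];
    std::printf("opIdx=%u: batch=%lld outer=%lld K=%lld -> %lld tile loads\n",
                opIdx, numRepBatch, numRepOuter, numRepK,
                numRepBatch * numRepOuter * numRepK);
  }
  return 0;
}
```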