@@ -15,6 +15,9 @@ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
   ];
 }

+def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+
+
 class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
                      Dialect dialect = TritonGPU_Dialect,
                      string baseCppClass = "::mlir::Attribute">
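
A note on the new trait declaration above: `NativeOpTrait<"MemDescViewTrait">` only names a C++ trait class, which TableGen resolves to `::mlir::OpTrait::MemDescViewTrait` by default. Below is a minimal sketch of the C++ counterpart such a declaration expects to exist; the marker-only body is an assumption for illustration and is not shown in this diff.

```cpp
#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace OpTrait {

// Marker trait with no verification hooks: it only tags ops whose result
// memdesc is a view of an existing allocation, so passes can test for it
// with op->hasTrait<MemDescViewTrait>().
template <typename ConcreteType>
class MemDescViewTrait : public TraitBase<ConcreteType, MemDescViewTrait> {};

} // namespace OpTrait
} // namespace mlir
```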
@@ -309,46 +312,54 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
       if(!mmaEnc)
         return get(context, 1, 1, 1, order, CTALayout);

+      int opIdx = dotOpEnc.getOpIdx();
+      auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
+
+      // index of the inner dimension in `order`
+      unsigned inner = (opIdx == 0) ? 0 : 1;
+
       // ---- begin Ampere & Hopper ----
       if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
-        return get(context, dotOpEnc.getOpIdx(), dotOpEnc.getKWidth(), shape, order, CTALayout, typeWidthInBit, needTrans);
+        // number of rows per phase
+        int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
+        perPhase = std::max<int>(perPhase, 1);
+        // {m, n, k} shape of a single mma instruction tile
+        std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
+        int vecWidth = 32 / typeWidthInBit;
+        if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
+          perPhase = std::max<int>(perPhase, 2 * vecWidth);
+        }
+        int rank = order.size();
+        // --- handle A operand ---
+        if (opIdx == 0) { // compute swizzling for A operand
+          int m = needTrans ? matShape[2] : matShape[0];
+          int k = needTrans ? matShape[0] : matShape[2];
+          int vec = (order[0] == rank-1) ? k : m;
+          int mmaStride = (order[0] == rank-1) ? m : k;
+          int maxPhase = std::max(mmaStride / perPhase, 1);
+          return get(context, vec, perPhase, maxPhase, order, CTALayout);
+        }
+
+        // --- handle B operand ---
+        if (opIdx == 1) {
+          // We compute vec and maxPhase from the m, n and k sizes of the
+          // mma instruction. When the matmul operand is transposed, we take
+          // that into account to recover m, n and k.
+          int n = needTrans ? matShape[2] : matShape[1];
+          int k = needTrans ? matShape[1] : matShape[2];
+          int vec = (order[0] == rank-1) ? n : k;
+          int mmaStride = (order[0] == rank-1) ? k : n;
+          int maxPhase = std::max(mmaStride / perPhase, 1);
+          return get(context, vec, perPhase, maxPhase, order, CTALayout);
+        }
+
+        llvm_unreachable("invalid operand index");
       }

       // ---- not implemented ----
       llvm_unreachable("unsupported swizzling for provided MMA version");
     }]>,

-    // NVIDIA constructor!
-    // TODO(lezcano): We should totally get rid of all these constructors...
-    AttrBuilder<(ins "int":$opIdx,
-                     "unsigned":$kWidth,
-                     "ArrayRef<int64_t>":$shape,
-                     "ArrayRef<unsigned>":$order,
-                     "CTALayoutAttr":$CTALayout,
-                     "unsigned":$bitwidth,
-                     "bool":$needTrans), [{
-      int K = getShapePerCTA(CTALayout.getCTASplitNum(), shape)[order[0]];
-      // Elems necessary to cover all the banks divided by the inner dimension
-      // This packs a few rows together for small K
-      int perPhase = std::max<int>(1024 / (bitwidth * K), 1);
-
-      int mmaStride = 8;
-      int vec = 4 * kWidth;
-      // needsTrans is equiv. to flipping the opIdx
-      if (needTrans)
-        std::swap(vec, mmaStride);
-      assert(opIdx == 0 || opIdx == 1);
-      int rank = order.size();
-      int kDim = opIdx == 0 ? rank-1 : rank-2;
-      if (order[0] != kDim)
-        std::swap(vec, mmaStride);
-      // Count how many vec elements are needed to cover all the banks
-      int maxPhase = std::max(std::min<int>(mmaStride, 1024 / (vec * bitwidth)), 1);
-      // Account for the row packing from perPhase: mmaStride / perPhase
-      maxPhase = std::max(maxPhase / perPhase, 1);
-      return get(context, vec, perPhase, maxPhase, order, CTALayout);
-    }]>,
-
     AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                      "ArrayRef<int64_t>":$shape,
                      "ArrayRef<unsigned>":$order,
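
To sanity-check the arithmetic restored in the hunk above, here is a small standalone replay of the builder's computation for one assumed configuration (fp16 A operand, kWidth = 2, 64 elements per CTA along the contiguous dimension, row-major order, no transpose). All inputs are illustrative and not taken from this diff.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Standalone replay of the swizzle arithmetic restored above, for one
// assumed configuration (all inputs are illustrative, not from this diff):
// fp16 A operand (opIdx = 0), kWidth = 2, 64 elements per CTA along the
// contiguous dimension, row-major order = {1, 0}, needTrans = false.
int main() {
  const int typeWidthInBit = 16, kWidth = 2, opIdx = 0;
  const int innerShapePerCTA = 64; // stands in for shapePerCTA[order[0]]
  const std::vector<unsigned> order = {1, 0};
  const bool needTrans = false;
  const unsigned inner = (opIdx == 0) ? 0 : 1;

  // number of rows per phase: 128 / (64 * 4 / 2) = 1
  int perPhase = std::max(128 / (innerShapePerCTA * 4 / kWidth), 1);

  // {m, n, k} shape of one mma instruction tile: {8, 8, 8}
  const int matShape[3] = {8, 8, 4 * kWidth};
  const int vecWidth = 32 / typeWidthInBit; // 2, equal to kWidth here
  if (vecWidth != kWidth && order[0] == inner)
    perPhase = std::max(perPhase, 2 * vecWidth); // not taken in this example

  const int rank = static_cast<int>(order.size());
  // A operand: pick (vec, mmaStride) based on which dimension is contiguous.
  const int m = needTrans ? matShape[2] : matShape[0];            // 8
  const int k = needTrans ? matShape[0] : matShape[2];            // 8
  const int vec = (order[0] == unsigned(rank - 1)) ? k : m;       // 8
  const int mmaStride = (order[0] == unsigned(rank - 1)) ? m : k; // 8
  const int maxPhase = std::max(mmaStride / perPhase, 1);         // 8

  std::printf("vec=%d perPhase=%d maxPhase=%d\n", vec, perPhase, maxPhase);
  return 0;
}
```

For this configuration the builder would return vec = 8, perPhase = 1, maxPhase = 8.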
@@ -387,6 +398,8 @@ def NVMMASharedEncodingAttr :
     This is meant to represent 2d tiled blocked layout.
     The full layout representation is described here:
     https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout
+    When the memdesc has more than 2 dimensions, the tiling is applied to 8 rows even if the outermost dimension is smaller than 8.
+    In this case `transposed` means that the contiguous dimension is the outermost dimension of the memdesc.
   }];


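
To make the new doc wording concrete, here is a hedged sketch of how the `transposed` flag could map to element order for a rank-3 memdesc of shape {S0, S1, S2}; the helper and the exact stride convention (full dimension reversal in the transposed case) are assumptions for exposition, not code from this change.

```cpp
#include <array>
#include <cstdio>

// Illustration only: one reading of the `transposed` wording added above,
// for a rank-3 memdesc of shape {S0, S1, S2}. The helper and the stride
// convention are assumptions for exposition, not code from this change.
std::array<long, 3> strides(const std::array<long, 3> &s, bool transposed) {
  if (!transposed)
    return {s[1] * s[2], s[2], 1}; // innermost dimension S2 is contiguous
  return {1, s[0], s[0] * s[1]};   // outermost dimension S0 is contiguous
}

int main() {
  const auto a = strides({2, 16, 64}, false); // {1024, 64, 1}
  const auto b = strides({2, 16, 64}, true);  // {1, 2, 32}
  std::printf("%ld %ld %ld / %ld %ld %ld\n",
              a[0], a[1], a[2], b[0], b[1], b[2]);
  return 0;
}
```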