@@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
361361 return get(context, vec, perPhase, maxPhase, order, CTALayout);
362362 }
363363
364- // ---- begin Ampere ----
365- if (mmaEnc.isAmpere()) {
364+ // ---- begin Ampere & Hopper ----
365+ if (mmaEnc.isAmpere() || mmaEnc.isHopper() ) {
366366 int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
367367 perPhase = std::max<int>(perPhase, 1);
368368 std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
@@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
397397 llvm_unreachable("invalid operand index");
398398 }
399399
400- // ---- begin version 3 ----
401- if (mmaEnc.isHopper()) {
402- llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
403- " is Hopper has not been implemented yet");
404- return $_get(context, 1, 1, 1, order, CTALayout, true);
405- }
406-
407400 // ---- not implemented ----
408401 llvm_unreachable("unsupported swizzling for provided MMA version");
409402 }]>,
@@ -1237,7 +1230,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12371230 SmallVector<int> getMMAv1Rep(int opIdx) const;
12381231 SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12391232 int getMMAv1Vec(int opIdx) const;
1240- SmallVector<int64_t> getMMAv2RepForOperand (ArrayRef<int64_t> shape,
1233+ SmallVector<int64_t> getMMAv2OrV3RepForOperand (ArrayRef<int64_t> shape,
12411234 int bitwidth, int kWidth, int opIdx) const;
12421235
12431236 bool supportReduction() const {
@@ -1336,6 +1329,27 @@ The parent field is the layout of d.
13361329kWidth defines the number of consecutive elements stored by one thread along the k dimension.
13371330Some layouts do not use this parameter, either because they have a fixed number of
13381331elements along the K dim, or they use all elements of the tensor along the K dim.
1332+
1333+ # WGMMA Notes
1334+ We require kWidth to be provided for Hopper because the dtype at loading might be
1335+ different from the dtype at WGMMA, due to casting. The kWidth is determined by the
1336+ dtype at WGMMA.
1337+
1338+ The encoded tensor consists of operand A for possibly multiple wgmma instructions.
1339+ For each wgmma, each warp in a warp group feeds a single "warp matrix".
1340+ Each warp matrix consists of 2x2 "quads".
1341+ Each thread holds several elements in each quad. Right before a wgmma,
1342+ the bitwidths of the elements in each quad
1343+ should add up to 32.
1344+
1345+ These values are stored unrolled in `elements`.
1346+ By convention, the dimensions are ordered as follows:
1347+ batch (only 1 batch for Hopper currently)
1348+ matM (m-index of the "warp matrix")
1349+ matK (k-index of the "warp matrix")
1350+ quadK (k-index of the "quad" in the core matrix)
1351+ quadM (m-index of the "quad" in the core matrix)
1352+ vecIdx (index of the element in the quad; this is always along the k-dim)
13391353 }];
13401354
13411355 let parameters = (
@@ -1346,16 +1360,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13461360 );
13471361
13481362 let builders = [
1349- // Specially for MMAV1(Volta)
13501363 AttrBuilder<(ins "unsigned":$opIdx,
13511364 "Attribute":$parent,
13521365 "Type":$eltTy), [{
13531366 NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
1354- if (!parentAttr || !parentAttr.isAmpere())
1355- return $_get(context, opIdx, parent, 0);
1367+ if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
1368+ return $_get(context, opIdx, parent, 0); // For MMAV1
1369+ // For MMAV2 and V3
13561370 unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
1357- unsigned MMAv2kWidth = 32 / bitwidth;
1358- return $_get(context, opIdx, parent, MMAv2kWidth );
1371+ unsigned kWidth = 32 / bitwidth;
1372+ return $_get(context, opIdx, parent, kWidth );
13591373 }]>
13601374 ];
13611375
0 commit comments