Skip to content

Commit 49a52a2

Browse files
Merge commit '4f6f76874ff623562903d5452d499cae3d40d448'
2 parents 1442ff4 + 4f6f768 commit 49a52a2

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

45 files changed

+1116
-722
lines changed

docs/python-api/triton.language.rst

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -59,6 +59,7 @@ Linear Algebra Ops
5959
:nosignatures:
6060

6161
dot
62+
dot_scaled
6263

6364

6465
Memory/Pointer Ops

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -214,8 +214,6 @@ bool atomicNeedsSharedMemory(Value result);
214214

215215
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
216216

217-
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
218-
219217
// Return true if the src and dst layout match.
220218
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
221219
RankedTensorType dstTy);

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 5 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
119119
let cppNamespace = "::mlir::triton";
120120
}
121121

122-
// Type for F8F6F4 kind of floats.
123-
def TT_F8F6F4TypeAttr : I32EnumAttr<
124-
"F8F6F4Type", "",
122+
// Type for ScaleDotElemType kind of floats.
123+
def TT_ScaleDotElemTypeAttr : I32EnumAttr<
124+
"ScaleDotElemType", "",
125125
[
126126
I32EnumAttrCase<"E4M3", 0, "e4m3">,
127127
I32EnumAttrCase<"E5M2", 1, "e5m2">,
128128
I32EnumAttrCase<"E2M3", 2, "e2m3">,
129129
I32EnumAttrCase<"E3M2", 3, "e3m2">,
130-
I32EnumAttrCase<"E2M1", 4, "e2m1">
130+
I32EnumAttrCase<"E2M1", 4, "e2m1">,
131+
I32EnumAttrCase<"BF16", 5, "bf16">
131132

132133
]>{
133134
let cppNamespace = "::mlir::triton";

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
685685

686686
let arguments = (
687687
ins
688-
// inputs are integer types as they are packed types and we currently
689-
// don't have a representation for those.
690-
TT_IntTensor:$lhs,
691-
TT_IntTensor:$rhs,
688+
// inputs are floats if we have a type for them, otherwise (fp4),
689+
// they are packed in pairs in an I8Tensor
690+
RankedTensorOf<[TT_Float,I8]>:$lhs,
691+
RankedTensorOf<[TT_Float,I8]>:$rhs,
692692
TT_FloatTensor:$c,
693-
TT_IntTensor:$lhs_scale,
694-
Optional<TT_IntTensor>:$rhs_scale,
695-
TT_F8F6F4TypeAttr:$lhs_type,
696-
TT_F8F6F4TypeAttr:$rhs_type
693+
RankedTensorOf<[I8]>:$lhs_scale,
694+
Optional<RankedTensorOf<[I8]>>:$rhs_scale,
695+
TT_ScaleDotElemTypeAttr:$lhs_type,
696+
TT_ScaleDotElemTypeAttr:$rhs_type
697697
);
698698

699699
let results = (outs TT_FloatTensor:$d);

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 29 additions & 15 deletions
Original file line number · Diff line number · Diff line change
@@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
361361
return get(context, vec, perPhase, maxPhase, order, CTALayout);
362362
}
363363

364-
// ---- begin Ampere ----
365-
if (mmaEnc.isAmpere()) {
364+
// ---- begin Ampere & Hopper ----
365+
if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
366366
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
367367
perPhase = std::max<int>(perPhase, 1);
368368
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
@@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
397397
llvm_unreachable("invalid operand index");
398398
}
399399

400-
// ---- begin version 3 ----
401-
if (mmaEnc.isHopper()) {
402-
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
403-
" is Hopper has not been implemented yet");
404-
return $_get(context, 1, 1, 1, order, CTALayout, true);
405-
}
406-
407400
// ---- not implemented ----
408401
llvm_unreachable("unsupported swizzling for provided MMA version");
409402
}]>,
@@ -1237,7 +1230,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12371230
SmallVector<int> getMMAv1Rep(int opIdx) const;
12381231
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12391232
int getMMAv1Vec(int opIdx) const;
1240-
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
1233+
SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
12411234
int bitwidth, int kWidth, int opIdx) const;
12421235

12431236
bool supportReduction() const {
@@ -1336,6 +1329,27 @@ The parent field is the layout of d.
13361329
kWidth defines number of consecutive elements stored by one thread along k dimension.
13371330
Some layouts do not use this parameter, either because they have a fixed number of
13381331
elements along the K dim, or they use all elements of the tensor along the K dim.
1332+
1333+
# WGMMA Notes
1334+
We require kWidth to be provided for Hopper because the dtype at loading might be
1335+
different from the dtype at WGMMA, due to casting. The kWidth is determined by the
1336+
dtype at WGMMA.
1337+
1338+
The encoded tensor consists of operand A for possibly multiple wgmma instructions.
1339+
For each wgmma, each warp in a warp group feeds a single "warp matrix"
1340+
Each warp matrix consists of 2x2 "quads".
1341+
Each thread holds several elements in each quad. Right before a wgmma,
1342+
the sum of bitwidth of
1343+
the elements in each quad should add up to 32.
1344+
1345+
These values are stored unrolled in `elements`.
1346+
The ordering of dimensions is as follows by convention:
1347+
batch (only 1 batch for Hopper currently)
1348+
matM (m-index of the "warp matrix")
1349+
matK (k-index of the "warp matrix")
1350+
quadK (k-index of the "quad" in the core matrix)
1351+
quadM (m-index of the "quad" in the core matrix)
1352+
vecIdx (index of the element in the quad; this is always along the k-dim)
13391353
}];
13401354

13411355
let parameters = (
@@ -1346,16 +1360,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13461360
);
13471361

13481362
let builders = [
1349-
// Specially for MMAV1(Volta)
13501363
AttrBuilder<(ins "unsigned":$opIdx,
13511364
"Attribute":$parent,
13521365
"Type":$eltTy), [{
13531366
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
1354-
if (!parentAttr || !parentAttr.isAmpere())
1355-
return $_get(context, opIdx, parent, 0);
1367+
if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
1368+
return $_get(context, opIdx, parent, 0); // For MMAV1
1369+
// For MMAV2 and V3
13561370
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
1357-
unsigned MMAv2kWidth = 32 / bitwidth;
1358-
return $_get(context, opIdx, parent, MMAv2kWidth);
1371+
unsigned kWidth = 32 / bitwidth;
1372+
return $_get(context, opIdx, parent, kWidth);
13591373
}]>
13601374
];
13611375

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
268268
let arguments = (ins
269269
TT_Tensor:$src,
270270
TT_Tensor:$scale,
271-
TT_F8F6F4TypeAttr:$fp_type);
271+
TT_ScaleDotElemTypeAttr:$fp_type);
272272
let results = (outs TT_Tensor:$result);
273273

274274
let assemblyFormat = [{

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
113113
Attribute srcLayout = srcTy.getEncoding();
114114
Attribute dstLayout = dstTy.getEncoding();
115115

116-
assert(!isMfmaToDotShortcut(srcTy, dstTy));
116+
assert(cvtNeedsSharedMemory(srcTy, dstTy));
117117

118118
// FIXME This is NOT entirely correct
119119
// This should be getElemOrder, but we don't have such a method

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -612,22 +612,6 @@ bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
612612
return matrixDimsCompatible && bDimCompatible;
613613
}
614614

615-
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
616-
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
617-
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
618-
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
619-
return false;
620-
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
621-
// improved. In addition, we can enable this shortcut for regular MFMA
622-
// layout when opIdx == 1.
623-
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
624-
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
625-
dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
626-
dotOperandLayout.getParent() == mfmaLayout &&
627-
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
628-
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
629-
}
630-
631615
// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
632616
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
633617
RankedTensorType dstTy) {
@@ -708,8 +692,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
708692
return !cvtReordersRegisters(srcTy, dstTy) &&
709693
!triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
710694
!isBlockedToDotShortcut(srcTy, dstTy) &&
711-
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
712-
!isMfmaToDotShortcut(srcTy, dstTy);
695+
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
713696
}
714697

715698
bool atomicNeedsSharedMemory(Value value) {

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 20 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -11,6 +11,25 @@ using namespace mlir::triton::gpu;
1111

1212
namespace mlir::triton::gpu {
1313

14+
namespace {
15+
16+
bool isDotOpTensorAndPacked(Type srcTy) {
17+
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
18+
if (!tensorTy)
19+
return false;
20+
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
21+
if (!encoding)
22+
return false;
23+
auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
24+
// By code convention, values for Hopper's dotOp-encoded tensors are not
25+
// packed
26+
if (!parentEnc || parentEnc.isHopper())
27+
return false;
28+
return true;
29+
}
30+
31+
} // namespace
32+
1433
Type getElementType(Value value) {
1534
auto type = value.getType();
1635
if (auto tensorType = dyn_cast<RankedTensorType>(type))
@@ -33,7 +52,7 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
3352
// If the parent of the dot operand is in block encoding, we don't need to
3453
// reorder elements
3554
auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
36-
if (!parentEncoding)
55+
if (!parentEncoding || parentEncoding.isHopper())
3756
return values;
3857
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
3958
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 24 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -1099,13 +1099,18 @@ LogicalResult DotOperandEncodingAttr::verify(
10991099
return emitError() << "triton_gpu.dot_op parent paramenter cannot be null";
11001100
}
11011101
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
1102-
if (kWidth != 0 && !parentAttr.isAmpere())
1102+
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
11031103
return emitError() << "triton_gpu.dot_op kWidth parameter can only be "
1104-
"non-zero for Ampere MMA parent";
1105-
if (kWidth == 0 && parentAttr.isAmpere())
1104+
"non-zero for Ampere or Hopper MMA parent";
1105+
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
11061106
return emitError()
11071107
<< "triton_gpu.dot_op kWidth parameter is mandatory for "
1108-
"Ampere MMA parent";
1108+
"Ampere or Hopper MMA parent";
1109+
if (opIdx != 0 && parentAttr.isHopper())
1110+
return emitError()
1111+
<< "triton_gpu.dot_op opIdx parameter must be 0 for "
1112+
"Hopper MMA parent, since Hopper WGMMA only allows first "
1113+
"operand to be in registers";
11091114
return success();
11101115
}
11111116

@@ -2053,17 +2058,20 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
20532058
int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
20542059
return 2 * getMMAv1Rep(opIdx)[opIdx];
20552060
}
2056-
SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
2061+
SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2OrV3RepForOperand(
20572062
ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
2063+
assert(isAmpere() || (isHopper() && opIdx == 0));
20582064
auto rank = shape.size();
20592065
auto warpsPerCTA = getWarpsPerCTA();
20602066

2067+
// {batch, m, n, k}
2068+
// Hopper path never uses the n value, since this method is only invoked
2069+
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
20612070
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
20622071
int numRepBatch =
20632072
rank == 3
20642073
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
20652074
: 1;
2066-
assert(isAmpere());
20672075

20682076
if (opIdx == 0)
20692077
return {numRepBatch,
@@ -2078,19 +2086,26 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
20782086
warpsPerCTA[rank - 1]))};
20792087
}
20802088
}
2089+
20812090
unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
20822091
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
20832092
auto shapePerCTA = getShapePerCTA(*this, shape);
20842093
int warpsPerCTAM = getWarpsPerCTA()[0];
20852094
int warpsPerCTAN = getWarpsPerCTA()[1];
20862095
// H100
20872096
if (isHopper()) {
2088-
return getTotalElemsPerThread(shape, eltTy);
2097+
assert(opIdx == 0);
2098+
auto instrMNK = getInstrShape();
2099+
int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
2100+
int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
2101+
// For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
2102+
// kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
2103+
return 4 * kWidth * repM * repK;
20892104
}
20902105
// A100
20912106
if (isAmpere()) {
2092-
auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(),
2093-
kWidth, opIdx);
2107+
auto rep = getMMAv2OrV3RepForOperand(
2108+
shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx);
20942109
if (opIdx == 0)
20952110
return 4 * rep[0] * rep[1] * rep[2];
20962111
if (opIdx == 1)

0 commit comments

Comments
 (0)