Commit 9e90089

[Backend] Implement scaled_dot(mxfp4, fp8) (#4904)
This PR includes triton-lang/triton#4891 and triton-lang/triton#4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All of these are marked as `FIXME [Dot LL]` so they are easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` that will get rid of all of them. triton-lang/triton#4895 is the first step in this direction.
1 parent 93de426 commit 9e90089
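
Background on the formats involved: mxfp4 is the OCP microscaling fp4 format, i.e. E2M1 elements (1 sign, 2 exponent, 1 mantissa bit) packed two per byte, with one shared e8m0 exponent scale per group of 32 elements. This packing is why the E2M1 branch added to `UpcastMXFPOp::inferReturnTypes` below doubles the trailing dimension and the kWidth. A host-side decode sketch for reference; the helper name and the low-nibble-first packing are assumptions for illustration, not code from this commit:

// Illustrative only: decode packed mxfp4 (E2M1) bytes with an e8m0 scale.
#include <cmath>
#include <cstdint>
#include <cstdio>

// One e2m1 nibble: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
// Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
float decodeE2M1(uint8_t nibble) {
  float sign = (nibble & 0x8) ? -1.0f : 1.0f;
  unsigned e = (nibble >> 1) & 0x3;
  unsigned m = nibble & 0x1;
  if (e == 0)
    return sign * 0.5f * m; // subnormals: 0.0 and 0.5
  return sign * std::ldexp(1.0f + 0.5f * m, int(e) - 1);
}

int main() {
  uint8_t packed = 0x27; // two fp4 values per byte; low nibble first (assumed)
  uint8_t scale = 126;   // shared e8m0 scale: 2^(126 - 127) = 0.5
  float s = std::ldexp(1.0f, int(scale) - 127);
  printf("%g %g\n", decodeE2M1(packed & 0xF) * s, decodeE2M1(packed >> 4) * s);
  // nibble 0x7 -> 6.0, nibble 0x2 -> 1.0; scaled: 3 and 0.5
}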

File tree

15 files changed (+317, -119)


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 7 deletions
@@ -250,13 +250,6 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
-
-// FIXME
-// Exposing to use it in LinearLayoutConversionsTest.cpp
-// Remove it once we fully activate the DotOperand conversion via LLs
-class DotOperandEncodingAttr;
-LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
-                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Analysis/Allocation.cpp

Lines changed: 6 additions & 1 deletion
@@ -115,7 +115,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   assert(!isMfmaToDotShortcut(srcTy, dstTy));
 
-  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
+  // FIXME This is NOT entirely correct
+  // This should be getElemOrder, but we don't have such a method
+  // TODO Implement getElemOrder and make sure it's consistent with
+  // getContigPerThread
+  auto inOrd = gpu::getThreadOrder(srcLayout);
+  auto outOrd = gpu::getThreadOrder(dstLayout);
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 32 additions & 0 deletions
@@ -404,6 +404,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return true;
     }
+    if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      if (auto nvidiaMma =
+              dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
+        if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
+          return false;
+        }
+        if (useLegacyMMAConversion) {
+          return false;
+        }
+        // FIXME [Dot LL]
+        // Enabling LL path for buggy kWidth path
+        bool largeKWidth =
+            dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
+        return largeKWidth && nvidiaMma.isAmpere();
+      }
+    }
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
     }

@@ -460,6 +476,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
+    // FIXME [Dot LL]
+    // We know it's just for largeKWidth case in Ampere
+    // In this case, we need to pack the outputs into i32
+    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      auto concat = [&](Value a, Value b) {
+        return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                   shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+      };
+
+      SmallVector<Value> outVals32(outVals.size() / 2);
+      for (int i = 0; i < outVals32.size(); ++i) {
+        outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+      }
+      outVals = outVals32;
+    }
+
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
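
A note on the second hunk: pairs of 16-bit outputs are packed into i32 values, first element in the low half, which matches the 32-bit register granularity that the MMAv2 dot-operand packing downstream expects. The same bit manipulation in plain C++ (an illustrative stand-in, with uint16_t for the bf16 bit patterns):

#include <cstdint>
#include <vector>

// Mirrors or_(zext(...), shl(zext(...), 16)) from the hunk above:
// two 16-bit lanes per 32-bit word, first value in the low half.
std::vector<uint32_t> packPairs(const std::vector<uint16_t> &vals) {
  std::vector<uint32_t> out(vals.size() / 2);
  for (size_t i = 0; i < out.size(); ++i)
    out[i] = uint32_t(vals[2 * i]) | (uint32_t(vals[2 * i + 1]) << 16);
  return out;
}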

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,16 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
+      // FIXME [Dot LL]
+      // We support this one via LLs, as the LocalLoad path is buggy
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
+        bool largeKWidth =
+            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+        if (mma.isAmpere() && largeKWidth) {
+          return;
+        }
+      }
+
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
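
The guard added here is the same `largeKWidth` predicate used in ConvertLayoutOpToLLVM.cpp and ReduceDataDuplication.cpp: a dot operand whose contiguous per-thread K slice exceeds 64 bits. A minimal sketch of the arithmetic (illustrative; the example values follow from the kWidths this PR assigns, e.g. bf16 upcast from mxfp4 gets kWidth 8):

#include <cassert>

// More than 64 contiguous K bits per thread -> handled by the LL path.
bool isLargeKWidth(unsigned kWidth, unsigned elemBits) {
  return kWidth * elemBits > 64;
}

int main() {
  assert(isLargeKWidth(8, 16));  // bf16, kWidth 8: 128 bits -> LL path
  assert(!isLargeKWidth(4, 16)); // bf16, kWidth 4: 64 bits -> legacy path
  assert(!isLargeKWidth(8, 8));  // fp8, kWidth 8: 64 bits -> legacy path
}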

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 31 additions & 37 deletions
@@ -8,6 +8,7 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"

@@ -234,8 +235,31 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k,
+  //         otherwise it is on m (resp. n)
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest running dimension
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3)
+  // Same if opIdx is 0 and kMajor is false
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout if should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false)
+  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:

@@ -245,40 +269,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.insert(order.begin(), 0);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
-    // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (dotOpLayout.getOpIdx() == 0) {
-      std::swap(order[0], order[1]);
-    }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  assert((rank == 2 || rank == 3) &&
-         "Invalid rank for dot operand order computation");
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make
-  // the following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2,
-  //   1, 0] for 3D tensors.
-  // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1,
-  //   2, 0] for 3D tensors.
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }

@@ -295,8 +287,8 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    auto rank = dotLayout.getWarpsPerCTA().size();
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());

@@ -1048,7 +1040,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {

@@ -2019,6 +2012,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
     ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
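
To make the new helper's behavior concrete, here is a standalone sketch with the four representative cases spelled out (same logic as `getOrderForDotOperand` above, rewritten with std types; the function name is a stand-in):

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                         bool kMajor) {
  assert((rank == 2 || rank == 3) && (opIdx == 0 || opIdx == 1));
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u); // [rank-1, ..., 1, 0]
  if (bool(opIdx) == kMajor)                   // swap the two fastest dims
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  // opIdx=0 is [m, k]: kMajor=true keeps k (dim 1) fastest -> [1, 0]
  assert(orderForDotOperand(0, 2, true) == std::vector<unsigned>({1, 0}));
  // opIdx=1 is [k, n]: kMajor=true makes k (dim 0) fastest -> [0, 1]
  assert(orderForDotOperand(1, 2, true) == std::vector<unsigned>({0, 1}));
  // rank 3: batch (dim 0) stays slowest in every case
  assert(orderForDotOperand(1, 3, true) == std::vector<unsigned>({1, 2, 0}));
  assert(orderForDotOperand(0, 3, false) == std::vector<unsigned>({1, 2, 0}));
}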

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 7 additions & 6 deletions
@@ -886,13 +886,14 @@ std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
+  } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+    // FIXME [Dot LL]
+    // Do this unconditionally
+    auto largeKWidth = getKWidth() == 8;
+    if (mma.isAmpere() && largeKWidth) {
+      return ampereDotToLinearLayout(shape, *this);
+    }
   }
-  // TODO Activate in a follow-up PR
-  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
-  //   if (mma.isAmpere()) {
-  //     return ampereDotToLinearLayout(shape, *this);
-  //   }
-  //}
   return std::nullopt;
 }
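
For context, `ampereDotToLinearLayout` plugs the dot-operand encoding into Triton's LinearLayout machinery, which models a layout as a linear map over GF(2): each set bit of a hardware index (register, lane, warp, ...) contributes a basis vector, and contributions combine by XOR. A heavily simplified sketch of that core idea; real LinearLayouts carry named input and output dimensions:

#include <cstdint>
#include <vector>

// Apply a GF(2)-linear map given one basis vector per input bit.
uint32_t applyLinearLayout(const std::vector<uint32_t> &bases, uint32_t idx) {
  uint32_t out = 0;
  for (size_t b = 0; b < bases.size(); ++b)
    if (idx & (1u << b)) // each set input bit XORs in one basis vector
      out ^= bases[b];
  return out;
}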

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 18 additions & 13 deletions
@@ -17,8 +17,9 @@ LogicalResult UpcastMXFPOp::verify() {
   auto xTy = getSrc().getType();
   auto scaleTy = getScale().getType();
 
-  if (xTy.getElementType() != FloatType::getBF16(getContext())) {
-    return emitOpError("element type of the first operand must be bf16");
+  if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
+      xTy.getElementType() != IntegerType::get(getContext(), 8)) {
+    return emitOpError("element type of the first operand must be bf16 or i8");
   }
 
   if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {

@@ -72,7 +73,7 @@ LogicalResult UpcastMXFPOp::verify() {
 }
 
 LogicalResult UpcastMXFPOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties opaqueProperties,
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   auto xTy = cast<RankedTensorType>(operands[0].getType());

@@ -82,21 +83,25 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
 
   auto encoding = xTy.getEncoding();
   if (!encoding) {
-    return emitOptionalError(location, "expected an encoding");
+    return emitOptionalError(loc, "expected an encoding");
   }
   if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
-    return emitOptionalError(location, "expected an mma layout encoding");
-  }
-  if (xShape.size() < 2) {
-    return emitOptionalError(location, "tensor rank must be at least 2");
+    return emitOptionalError(loc, "expected a dotOperand encoding");
   }
 
-  // For now we just return the input encoding. For fp4 we'll need to cast from
-  // tf32 to fp16 encoding and multiply the shape by two
-  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
-         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+  if (typeEncoded == F8F6F4Type::E2M1) {
+    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+    auto newVEncoding = DotOperandEncodingAttr::get(
+        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+        oldEncoding.getKWidth() * 2);
+    auto newShape = SmallVector<int64_t>(xShape);
+    newShape.back() *= 2;
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
+  } else {
+    inferredReturnTypes.push_back(xTy);
+  }
 
-  inferredReturnTypes.push_back(xTy);
   return success();
 }
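
The E2M1 branch is where the fp4 packing surfaces in the type system: an i8 tensor holding two fp4 values per element upcasts to a bf16 tensor twice as wide along the innermost axis, and kWidth doubles so each thread still covers the same slice of K. A sketch of that shape/kWidth arithmetic, with plain types standing in for the MLIR ones (names are stand-ins):

#include <cstdint>
#include <vector>

struct UpcastInfo {
  std::vector<int64_t> shape; // bf16 result shape
  unsigned kWidth;            // dot-operand kWidth of the result
};

UpcastInfo inferMxfp4Upcast(std::vector<int64_t> packedShape,
                            unsigned packedKWidth) {
  packedShape.back() *= 2; // two e2m1 values per i8 element
  return {packedShape, packedKWidth * 2};
}

int main() {
  // A packed [128, 32] i8 operand with kWidth 4 becomes bf16 [128, 64],
  // kWidth 8 -- matching newShape.back() *= 2 and getKWidth() * 2 above.
  return inferMxfp4Upcast({128, 32}, 4).kWidth == 8 ? 0 : 1;
}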

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 38 additions & 24 deletions
@@ -406,7 +406,7 @@ class ScaledBlockedToMMAv2
     auto ctx = dotOp.getContext();
 
     // Check that rhs scale is null
-    assert(dotOp.getRhsScale() == nullptr && "rhs scale must be null");
+    assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI");
 
     // operands
     auto a = dotOp.getLhs();

@@ -426,10 +426,11 @@ class ScaledBlockedToMMAv2
       }
     };
 
-    assert(aType == F8F6F4Type::E4M3 ||
-           aType == F8F6F4Type::E5M2 && "lhs just supports fp8");
+    assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
+            aType == F8F6F4Type::E2M1) &&
+           "NYI: lhs supports fp4 or fp8");
     assert(bType == F8F6F4Type::E4M3 ||
-           bType == F8F6F4Type::E5M2 && "rhs just supports fp8");
+           bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");
 
     // TODO run accelerate matmul on A and B first to choose their layouts
     // Set return type

@@ -440,6 +441,7 @@ class ScaledBlockedToMMAv2
     auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
                                              rewriter.getBF16Type(), numWarps);
     auto CTALayout = getCTALayout(oldRetType.getEncoding());
+    // TODO Use warpsPerTileV2
     SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
     auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor,
                                              /*versionMinor=*/0, warpsPerCTA,

@@ -452,27 +454,39 @@ class ScaledBlockedToMMAv2
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
 
-    auto toMMABf16 = [&newRetType, &rewriter, &ctx,
-                      &enumToType](TypedValue<RankedTensorType> v, int idx,
-                                   F8F6F4Type type) {
-      // MMAv2 Layout
+    auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
+                         TypedValue<RankedTensorType> v, int idx,
+                         F8F6F4Type type) -> TypedValue<RankedTensorType> {
       auto vType = v.getType();
-      auto newVEncoding = DotOperandEncodingAttr::get(
-          ctx, idx, newRetType.getEncoding(), enumToType((type)));
-      auto newVType = RankedTensorType::get(
-          v.getType().getShape(), v.getType().getElementType(), newVEncoding);
-      v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
-
-      // Bitcast
-      auto vTypeFp8 = RankedTensorType::get(
-          vType.getShape(), rewriter.getFloat8E4M3FNType(), newVEncoding);
-      v = cast<TypedValue<RankedTensorType>>(
-          rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
-
-      // Convert to bf16
-      auto vTypeBf16 = RankedTensorType::get(
-          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
-      return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      if (type == F8F6F4Type::E2M1) {
+        // A bit too dynamically typed...
+        // perhaps return ints in both cases?
+
+        auto retEnc = dyn_cast<NvidiaMmaEncodingAttr>(newRetType.getEncoding());
+        auto newVEncoding = DotOperandEncodingAttr::get(
+            ctx, idx, newRetType.getEncoding(), /*kWidth=*/4);
+        auto newVType = RankedTensorType::get(
+            vType.getShape(), vType.getElementType(), newVEncoding);
+        return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+      } else {
+        assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
+        auto newVEncoding = DotOperandEncodingAttr::get(
+            ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
+        auto newVType = RankedTensorType::get(
+            vType.getShape(), vType.getElementType(), newVEncoding);
+        v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+        // Bitcast
+        auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
+                                              enumToType(type), newVEncoding);
+        v = cast<TypedValue<RankedTensorType>>(
+            rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+
+        // Convert to bf16
+        auto vTypeBf16 = RankedTensorType::get(
+            vType.getShape(), rewriter.getBF16Type(), newVEncoding);
+        return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      }
     };
     a = toMMABf16(a, 0, aType);
     b = toMMABf16(b, 1, bType);
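
The rewritten `toMMABf16` takes one of two routes: fp8 operands (E4M3/E5M2) get a kWidth=8 dot-operand layout, a bitcast to the real fp8 element type, and an FpToFpOp conversion to bf16, while fp4 operands (E2M1) keep their packed i8 element type with kWidth=4, leaving the upcast to bf16 to the UpcastMXFPOp path (not visible in this hunk). A decision sketch with stand-in names, not the MLIR types:

#include <cassert>

enum class F8F6F4 { E4M3, E5M2, E2M1 };

struct OperandPlan {
  unsigned kWidth;
  bool convertToBf16Now; // fp8 only; fp4 is upcast later by UpcastMXFPOp
};

OperandPlan planOperand(F8F6F4 type) {
  if (type == F8F6F4::E2M1)
    return {/*kWidth=*/4, /*convertToBf16Now=*/false}; // stays packed i8
  return {/*kWidth=*/8, /*convertToBf16Now=*/true};    // bitcast + FpToFp
}

int main() {
  assert(planOperand(F8F6F4::E2M1).kWidth == 4);
  assert(planOperand(F8F6F4::E4M3).convertToBf16Now);
}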

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 7 additions & 0 deletions
@@ -44,6 +44,13 @@ class TritonGPUReduceDataDuplicationPass
       return;
     if (!cvtNeedsSharedMemory(srcType, dstType))
       return;
+    // FIXME [Dot LL]
+    // We support this one via LLs, as the LocalLoad path is buggy
+    bool largeKWidth =
+        dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+    if (largeKWidth) {
+      return;
+    }
     auto srcOrder = triton::gpu::getOrder(srcEncoding);
     auto rank = srcOrder.size();
     SmallVector<unsigned> sharedOrder;
