
Commit a626ab8

Revert "[Backend] Implement scaled_dot(mxfp4, fp8) (#4904)"
This reverts commit 9e90089.
1 parent c44a95b commit a626ab8

15 files changed: +103, -392 lines changed

15 files changed

+103
-392
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions
@@ -250,6 +250,13 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
+
+// FIXME
+// Exposing to use it in LinearLayoutConversionsTest.cpp
+// Remove it once we fully activate the DotOperand conversion via LLs
+class DotOperandEncodingAttr;
+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Analysis/Allocation.cpp

Lines changed: 1 addition & 6 deletions
@@ -115,12 +115,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

   assert(!isMfmaToDotShortcut(srcTy, dstTy));

-  // FIXME This is NOT entirely correct
-  // This should be getElemOrder, but we don't have such a method
-  // TODO Implement getElemOrder and make sure it's consistent with
-  // getContigPerThread
-  auto inOrd = gpu::getThreadOrder(srcLayout);
-  auto outOrd = gpu::getThreadOrder(dstLayout);
+  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
   scratchConfig.order = outOrd;

   unsigned srcContigPerThread =

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 32 deletions
@@ -404,22 +404,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return true;
     }
-    if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      if (auto nvidiaMma =
-              dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
-        if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
-          return false;
-        }
-        if (useLegacyMMAConversion) {
-          return false;
-        }
-        // FIXME [Dot LL]
-        // Enabling LL path for buggy kWidth path
-        bool largeKWidth =
-            dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
-        return largeKWidth && nvidiaMma.isAmpere();
-      }
-    }
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
     }

@@ -476,22 +460,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }

-    // FIXME [Dot LL]
-    // We know it's just for largeKWidth case in Ampere
-    // In this case, we need to pack the outputs into i32
-    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
-      auto concat = [&](Value a, Value b) {
-        return or_(zext(i32_ty, bitcast(a, i16_ty)),
-                   shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
-      };
-
-      SmallVector<Value> outVals32(outVals.size() / 2);
-      for (int i = 0; i < outVals32.size(); ++i) {
-        outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
-      }
-      outVals = outVals32;
-    }
-
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
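For reference, the packing step removed above combined each pair of 16-bit dot-operand values into one i32 register, low half first. A minimal standalone restatement of that bit manipulation (hypothetical helper name, plain C++ instead of the LLVM-dialect builders used in the pass):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // concatI16 mirrors the removed `concat` lambda: zero-extend both halves
    // to 32 bits and place the second value in the upper 16 bits.
    static uint32_t concatI16(uint16_t lo, uint16_t hi) {
      return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
    }

    int main() {
      std::vector<uint16_t> outVals = {0x1111, 0x2222, 0x3333, 0x4444};
      std::vector<uint32_t> outVals32(outVals.size() / 2);
      for (size_t i = 0; i < outVals32.size(); ++i)
        outVals32[i] = concatI16(outVals[2 * i], outVals[2 * i + 1]);
      assert(outVals32[0] == 0x22221111u && outVals32[1] == 0x44443333u);
      return 0;
    }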

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 10 deletions
@@ -90,16 +90,6 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
       auto dstDotOp =
           dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
       if (srcBlocked && dstDotOp) {
-        // FIXME [Dot LL]
-        // We support this one via LLs, as the LocalLoad path is buggy
-        if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
-          bool largeKWidth =
-              dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
-          if (mma.isAmpere() && largeKWidth) {
-            return;
-          }
-        }
-
         Attribute sharedMemorySpace =
             triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
         auto tmpType = MemDescType::get(

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 30 additions & 34 deletions
@@ -11,7 +11,6 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
-#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"

@@ -238,31 +237,8 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }

-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                            bool kMajor) {
-  // kMajor: if true, the matrix is fastest-running on k,
-  // otherwise it is on m (resp. n)
-  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
-  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
-  // batch (if rank == 3) is always the slowest running dimension
-  assert(rank == 2 || rank == 3);
-  assert(opIdx == 0 || opIdx == 1);
-  SmallVector<unsigned> order(rank);
-  std::iota(order.rbegin(), order.rend(), 0);
-  // If opIdx is 1 and kMajor is true, the order is [0, 1]
-  // (resp. [1, 2, 0] if rank == 3)
-  // Same if opIdx is 0 and kMajor is false
-  if (bool(opIdx) == kMajor) {
-    std::swap(order[0], order[1]);
-  }
-  return order;
-}
-
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
-  // FIXME: This mmaLayout if should just return
-  // getOrderForDotOperand(0, order.size(), kMajor=false)
-  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:

@@ -271,9 +247,30 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.erase(it);
       order.insert(order.begin(), 0);
     }
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
-                                  /*kMajor*/ false);
+  }
+  return order;
+}
+
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
+  SmallVector<unsigned> order(rank);
+  // The 'order' field typically represents a descending sorted array of
+  // dimensions based on contiguity. For instance, in axisInfo utilities that
+  // retrieve tensor contiguity, it's assumed that the dimension with the
+  // highest contiguity corresponds to order[0].
+  //
+  // The relation between contiguity and order is only relevant if the layout
+  // interfaces with HBM, as is the case when we load tensor from HBM to
+  // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
+  // following assumptions about tensor layouts:
+  // - Tensor A (opIdx == 0) is considered to be row-major.
+  // - Tensor B (opIdx == 1) is considered to be column-major.
+  //
+  // Based on these assumptions, we define the following orders:
+  // - For opIdx == 0, we assume an order of [1, 0].
+  // - For opIdx == 1, we assume an order of [0, 1].
+  std::iota(order.rbegin(), order.rend(), 0);
+  if (opIdx == 1) {
+    std::swap(order[0], order[1]);
   }
   return order;
 }

@@ -290,12 +287,13 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = dotLayout.getWarpsPerCTA().size();
+    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
+    SmallVector<unsigned> order(rank);
     if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
+      return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    } else {
+      std::iota(order.rbegin(), order.rend(), 0);
     }
-    SmallVector<unsigned> order(rank);
-    std::iota(order.rbegin(), order.rend(), 0);
     return order;
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {

@@ -1061,8 +1059,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                               /*kMajor*/ true);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {

@@ -2045,7 +2042,6 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
                                                         int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
-
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
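The getOrderForDotOperand restored above drops the kMajor flag and bakes in the row-major-A / column-major-B assumption described in its comment. A small self-contained sketch of the resulting order computation (std::vector instead of SmallVector, otherwise the same logic):

    #include <cassert>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Same logic as the restored getOrderForDotOperand: start from the
    // descending order {rank-1, ..., 1, 0} and swap the two fastest
    // dimensions for operand B (opIdx == 1), assumed column-major.
    std::vector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
      std::vector<unsigned> order(rank);
      std::iota(order.rbegin(), order.rend(), 0);
      if (opIdx == 1)
        std::swap(order[0], order[1]);
      return order;
    }

    int main() {
      assert((getOrderForDotOperand(0, 2) == std::vector<unsigned>{1, 0}));
      assert((getOrderForDotOperand(1, 2) == std::vector<unsigned>{0, 1}));
      assert((getOrderForDotOperand(1, 3) == std::vector<unsigned>{1, 2, 0}));
      return 0;
    }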

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 2 additions & 9 deletions
@@ -827,15 +827,8 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
-  } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
-    // FIXME [Dot LL]
-    // Do this unconditionally
-    auto largeKWidth = getKWidth() == 8;
-    if (mma.isAmpere() && largeKWidth) {
-      return ampereDotToLinearLayout(shape, *this);
-    }
-  } else if (auto dpasLayout =
-                 llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+  }
+  if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }
lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 13 additions & 18 deletions
@@ -17,9 +17,8 @@ LogicalResult UpcastMXFPOp::verify() {
   auto xTy = getSrc().getType();
   auto scaleTy = getScale().getType();

-  if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
-      xTy.getElementType() != IntegerType::get(getContext(), 8)) {
-    return emitOpError("element type of the first operand must be bf16 or i8");
+  if (xTy.getElementType() != FloatType::getBF16(getContext())) {
+    return emitOpError("element type of the first operand must be bf16");
   }

   if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {

@@ -73,7 +72,7 @@ LogicalResult UpcastMXFPOp::verify() {
 }

 LogicalResult UpcastMXFPOp::inferReturnTypes(
-    MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties opaqueProperties,
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   auto xTy = cast<RankedTensorType>(operands[0].getType());

@@ -83,25 +82,21 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(

   auto encoding = xTy.getEncoding();
   if (!encoding) {
-    return emitOptionalError(loc, "expected an encoding");
+    return emitOptionalError(location, "expected an encoding");
   }
   if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
-    return emitOptionalError(loc, "expected a dotOperand encoding");
+    return emitOptionalError(location, "expected an mma layout encoding");
   }
-
-  if (typeEncoded == F8F6F4Type::E2M1) {
-    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-    auto newVEncoding = DotOperandEncodingAttr::get(
-        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
-        oldEncoding.getKWidth() * 2);
-    auto newShape = SmallVector<int64_t>(xShape);
-    newShape.back() *= 2;
-    inferredReturnTypes.push_back(
-        RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
-  } else {
-    inferredReturnTypes.push_back(xTy);
+  if (xShape.size() < 2) {
+    return emitOptionalError(location, "tensor rank must be at least 2");
   }

+  // For now we just return the input encoding. For fp4 we'll need to cast from
+  // tf32 to fp16 encoding and multiply the shape by two
+  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
+         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+
+  inferredReturnTypes.push_back(xTy);
   return success();
 }
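The E2M1 branch removed here doubled kWidth and the innermost dimension, reflecting that mxfp4 packs two 4-bit values into each i8 element; after the revert only the fp8 path remains, which returns the input type unchanged. A minimal sketch of the shape rule that branch encoded (hypothetical helper name, plain C++ rather than MLIR types):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // upcastedFp4Shape restates the reverted rule: an i8 tensor holding mxfp4
    // data expands to twice as many bf16 values along its last dimension.
    static std::vector<int64_t> upcastedFp4Shape(std::vector<int64_t> shape) {
      shape.back() *= 2;
      return shape;
    }

    int main() {
      assert((upcastedFp4Shape({128, 32}) == std::vector<int64_t>{128, 64}));
      return 0;
    }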

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 24 additions & 38 deletions
@@ -406,7 +406,7 @@ class ScaledBlockedToMMAv2
     auto ctx = dotOp.getContext();

     // Check that rhs scale is null
-    assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI");
+    assert(dotOp.getRhsScale() == nullptr && "rhs scale must be null");

     // operands
     auto a = dotOp.getLhs();

@@ -426,11 +426,10 @@ class ScaledBlockedToMMAv2
       }
     };

-    assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
-            aType == F8F6F4Type::E2M1) &&
-           "NYI: lhs supports fp4 or fp8");
+    assert(aType == F8F6F4Type::E4M3 ||
+           aType == F8F6F4Type::E5M2 && "lhs just supports fp8");
     assert(bType == F8F6F4Type::E4M3 ||
-           bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");
+           bType == F8F6F4Type::E5M2 && "rhs just supports fp8");

     // TODO run accelerate matmul on A and B first to choose their layouts
     // Set return type

@@ -441,7 +440,6 @@ class ScaledBlockedToMMAv2
     auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
                                              rewriter.getBF16Type(), numWarps);
     auto CTALayout = getCTALayout(oldRetType.getEncoding());
-    // TODO Use warpsPerTileV2
     SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
     auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor,
                                              /*versionMinor=*/0, warpsPerCTA,

@@ -454,39 +452,27 @@ class ScaledBlockedToMMAv2
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);

-    auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
-                         TypedValue<RankedTensorType> v, int idx,
-                         F8F6F4Type type) -> TypedValue<RankedTensorType> {
+    auto toMMABf16 = [&newRetType, &rewriter, &ctx,
+                      &enumToType](TypedValue<RankedTensorType> v, int idx,
+                                   F8F6F4Type type) {
+      // MMAv2 Layout
       auto vType = v.getType();
-      if (type == F8F6F4Type::E2M1) {
-        // A bit too dynamically typed...
-        // perhaps return ints in both cases?
-
-        auto retEnc = dyn_cast<NvidiaMmaEncodingAttr>(newRetType.getEncoding());
-        auto newVEncoding = DotOperandEncodingAttr::get(
-            ctx, idx, newRetType.getEncoding(), /*kWidth=*/4);
-        auto newVType = RankedTensorType::get(
-            vType.getShape(), vType.getElementType(), newVEncoding);
-        return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
-      } else {
-        assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
-        auto newVEncoding = DotOperandEncodingAttr::get(
-            ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
-        auto newVType = RankedTensorType::get(
-            vType.getShape(), vType.getElementType(), newVEncoding);
-        v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
-
-        // Bitcast
-        auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
-                                              enumToType(type), newVEncoding);
-        v = cast<TypedValue<RankedTensorType>>(
-            rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
-
-        // Convert to bf16
-        auto vTypeBf16 = RankedTensorType::get(
-            vType.getShape(), rewriter.getBF16Type(), newVEncoding);
-        return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
-      }
+      auto newVEncoding = DotOperandEncodingAttr::get(
+          ctx, idx, newRetType.getEncoding(), enumToType((type)));
+      auto newVType = RankedTensorType::get(
+          v.getType().getShape(), v.getType().getElementType(), newVEncoding);
+      v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+      // Bitcast
+      auto vTypeFp8 = RankedTensorType::get(
+          vType.getShape(), rewriter.getFloat8E4M3FNType(), newVEncoding);
+      v = cast<TypedValue<RankedTensorType>>(
+          rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+
+      // Convert to bf16
+      auto vTypeBf16 = RankedTensorType::get(
+          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
+      return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
     };
     a = toMMABf16(a, 0, aType);
     b = toMMABf16(b, 1, bType);

lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp

Lines changed: 0 additions & 7 deletions
@@ -44,13 +44,6 @@ class TritonGPUReduceDataDuplicationPass
         return;
       if (!cvtNeedsSharedMemory(srcType, dstType))
         return;
-      // FIXME [Dot LL]
-      // We support this one via LLs, as the LocalLoad path is buggy
-      bool largeKWidth =
-          dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
-      if (largeKWidth) {
-        return;
-      }
       auto srcOrder = triton::gpu::getOrder(srcEncoding);
       auto rank = srcOrder.size();
       SmallVector<unsigned> sharedOrder;
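The guard removed here, like the ones removed from ConvertLayoutOpToLLVM.cpp and DecomposeUnsupportedConversions.cpp above, keys on the same threshold: kWidth times the element bit width exceeding 64 bits. A standalone restatement of that predicate (hypothetical name, plain C++):

    #include <cassert>

    // isLargeKWidth restates the removed check: a thread's contiguous k-slice
    // counts as "large" when kWidth * elementBitWidth exceeds 64 bits.
    static bool isLargeKWidth(unsigned kWidth, unsigned elemBits) {
      return kWidth * elemBits > 64;
    }

    int main() {
      assert(isLargeKWidth(8, 16));  // kWidth=8 with bf16 -> 128 bits
      assert(!isLargeKWidth(4, 16)); // kWidth=4 with bf16 -> 64 bits
      return 0;
    }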
