Skip to content

Commit 929142b

Browse files
authored
[MXFP] Implement SW emulation of dot_scale as a decomposition (#5475)
The goal of this PR is to remove the shuffles that were necessary to distribute the scale within `dot_scale`. Instead, we simply rely on layout propagation to load the scales in the right layout from smem. In this PR we do a number of things:
- Implement a new `fp4_to_fp` op with full forward and backward type inference
- Remove UpcastMXFP
- Remove all the complex layout choices within the `dot_scale` decomposition and instead rely on the accelerate-matmul pass and remove-layout-conversions to do the right thing
- Decompose `dot_scale` into simple Triton ops so that the pass can be shared between all backends

Still to do:
- Split `DecomposeScaledBlocked` into its own file and share the pass between NVIDIA and AMD
1 parent ff77e98 commit 929142b

File tree

29 files changed

+1252
-1356
lines changed

29 files changed

+1252
-1356
lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ class DialectInferLayoutInterface
8383
virtual LogicalResult
8484
verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
8585
Attribute operandEncodingB) const = 0;
86+
87+
virtual LogicalResult
88+
inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
89+
Attribute &outEnc, bool fwdInference,
90+
std::optional<Location> loc) const = 0;
8691
};
8792

8893
class DialectVerifyTensorLayoutInterface

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,10 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
456456
If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
457457
The compiler is still free to change it for better performance.
458458
}];
459+
let builders = [
460+
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "TypedValue<RankedTensorType>":$src)>
461+
];
462+
459463
let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
460464
let results = (outs TT_Tensor:$result);
461465
let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -283,31 +283,29 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
283283
}];
284284
}
285285

286-
def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
287-
let summary = "Convert an mxfp tensor to bf16/fp16";
286+
def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
287+
let summary = "Upcast fp4 (e2m1) to fp";
288288

289289
let hasVerifier = 1;
290290

291291
let description = [{
292-
Compute the bf16 encoded in the given mxfp number as per
293-
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
294-
}];
295-
let arguments = (
296-
ins
297-
TT_Tensor:$src,
298-
TT_Tensor:$scale,
299-
TT_ScaleDotElemTypeAttr:$fp_type,
300-
BoolAttr:$fastMath
301-
);
302-
let results = (outs TT_Tensor:$result);
292+
Upcast fp4 (e2m1) represented packed as i8s to fp.
303293

304-
let assemblyFormat = [{
305-
$src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
294+
The lower 4 bits of the i8s represent the first fp4 element, and the upper 4 bits
295+
the second fp4 element.
296+
297+
The `axis` attribute specifies the axis along which the fp4 elements are packed.
306298
}];
307299

308-
let extraClassDeclaration = [{
309-
static RankedTensorType deduceOutputType(
310-
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
300+
let builders = [
301+
OpBuilder<(ins "TypedValue<RankedTensorType>":$src, "Type":$elemType, "int32_t":$axis)>
302+
];
303+
304+
let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis);
305+
let results = (outs TT_FloatTensor:$result);
306+
307+
let assemblyFormat = [{
308+
$src attr-dict `:` type($src) `->` type($result)
311309
}];
312310
}
313311

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include "mlir/IR/PatternMatch.h"
2+
3+
namespace mlir::triton::gpu {
4+
5+
void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns,
6+
int benefit);
7+
8+
} // namespace mlir::triton::gpu

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,23 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
693693
}
694694

695695
//-- ReshapeOp --
696+
697+
void ReshapeOp::build(OpBuilder &builder, OperationState &state,
698+
ArrayRef<int64_t> shape,
699+
TypedValue<RankedTensorType> src) {
700+
auto srcTy = src.getType();
701+
auto srcEnc = srcTy.getEncoding();
702+
Attribute dstEnc;
703+
if (srcEnc) {
704+
auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
705+
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
706+
dstEnc, state.location);
707+
assert(succeeded(result));
708+
}
709+
auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
710+
build(builder, state, dstTy, src);
711+
}
712+
696713
LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
697714
if (op.getEfficientLayout())
698715
return failure();
@@ -769,6 +786,10 @@ LogicalResult ReshapeOp::verify() {
769786
OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
770787
auto srcVal = getSrc();
771788
auto dstTy = getType();
789+
// Fold trivial cast
790+
if (srcVal.getType() == dstTy) {
791+
return srcVal;
792+
}
772793

773794
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
774795
const llvm::fltSemantics &semantic = resElemType.getFloatSemantics();

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2843,6 +2843,98 @@ struct TritonGPUInferLayoutInterface
28432843
ArrayRef(enc.getCTAOrder()).drop_front(1)));
28442844
return success();
28452845
}
2846+
2847+
LogicalResult
2848+
inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
2849+
Attribute &outEnc, bool fwdInference,
2850+
std::optional<Location> loc) const override {
2851+
// We implement two legacy layout propagations
2852+
// Once we fully migrate to LinearLayouts, we can remove these.
2853+
auto *ctx = getContext();
2854+
auto rank = shape.size();
2855+
// The output encoding will only be a legacy encoding if the axis is the
2856+
// fastest running dimension.
2857+
if (getOrder(inEnc)[axis] == 0) {
2858+
// Dot operand: double kWidth if kDim == axis.
2859+
if (auto dotEnc = mlir::dyn_cast<DotOperandEncodingAttr>(inEnc)) {
2860+
auto kWidth = dotEnc.getKWidth();
2861+
if (fwdInference) {
2862+
kWidth *= 2;
2863+
} else {
2864+
if (kWidth > 1) {
2865+
// bwd inference
2866+
kWidth /= 2;
2867+
} else {
2868+
return emitOptionalError(loc,
2869+
"Fp4ToFpOp requires at least 2 elements "
2870+
"per thread in the axis dimension");
2871+
}
2872+
}
2873+
outEnc = DotOperandEncodingAttr::get(ctx, dotEnc.getOpIdx(),
2874+
dotEnc.getParent(), kWidth);
2875+
return success();
2876+
}
2877+
2878+
// Blocked layout: double elemsPerThread[axis].
2879+
if (auto blockedEnc = mlir::dyn_cast<BlockedEncodingAttr>(inEnc)) {
2880+
auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread());
2881+
if (fwdInference) {
2882+
sizePerThread[axis] *= 2;
2883+
} else {
2884+
if (sizePerThread[axis] > 1) {
2885+
sizePerThread[axis] /= 2;
2886+
} else {
2887+
return emitOptionalError(
2888+
loc, "Fp4ToFpOp requires at least 2 elements per "
2889+
"thread in the axis dimension");
2890+
}
2891+
}
2892+
outEnc = BlockedEncodingAttr::get(
2893+
ctx, sizePerThread, blockedEnc.getThreadsPerWarp(),
2894+
blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(),
2895+
blockedEnc.getCTALayout());
2896+
return success();
2897+
}
2898+
}
2899+
2900+
auto ll = toLinearLayout(shape, inEnc);
2901+
2902+
auto kRegister = StringAttr::get(ctx, "register");
2903+
auto outDims = llvm::to_vector(ll.getOutDimNames());
2904+
LinearLayout newLl = LinearLayout::empty();
2905+
if (fwdInference) {
2906+
auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
2907+
newLl = split * ll;
2908+
// FIXME!!!!
2909+
// operator* transposes the output dimensions??!! WTF
2910+
newLl = newLl.transposeOuts(outDims);
2911+
} else {
2912+
// TODO This requires a division algorithm!
2913+
// Implement manually ll.divideLeft(split)
2914+
auto contiguousElems =
2915+
LinearEncodingAttr::get(ctx, ll).getContigPerThread();
2916+
if (contiguousElems[axis] > 1) {
2917+
LinearLayout::BasesT newBases;
2918+
for (const auto &basesDim : ll.getBases()) {
2919+
std::vector<std::vector<int32_t>> newBasesDim;
2920+
for (auto base : basesDim.second) {
2921+
if (base[axis] == 1) {
2922+
continue;
2923+
}
2924+
base[axis] /= 2;
2925+
newBasesDim.push_back(std::move(base));
2926+
}
2927+
newBases.insert({basesDim.first, std::move(newBasesDim)});
2928+
}
2929+
newLl = LinearLayout(std::move(newBases), std::move(outDims));
2930+
} else {
2931+
return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
2932+
"per thread in the axis dimension");
2933+
}
2934+
}
2935+
outEnc = LinearEncodingAttr::get(ctx, newLl);
2936+
return success();
2937+
}
28462938
};
28472939

28482940
struct TritonGPUVerifyTensorLayoutInterface

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 53 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -329,121 +329,64 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
329329
patterns.add<CanonicalizeConvertFromSplit>(context);
330330
}
331331

332-
LogicalResult UpcastMXFPOp::verify() {
333-
auto fpType = getFpType();
334-
335-
auto xTy = getSrc().getType();
336-
auto scaleTy = getScale().getType();
337-
Builder b(getContext());
338-
if (xTy.getElementType() != b.getBF16Type() &&
339-
xTy.getElementType() != b.getF16Type() &&
340-
xTy.getElementType() != b.getI8Type()) {
341-
return emitOpError(
342-
"element type of the first operand must be bf16/fp16 or i8");
343-
}
344-
345-
if (scaleTy.getElementType() != b.getI8Type()) {
346-
return emitOpError("element type of the second operand must be uint8");
347-
}
348-
349-
auto xShape = xTy.getShape();
350-
auto scaleShape = scaleTy.getShape();
351-
352-
if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
353-
return emitOpError(
354-
"operands must have the same number of dimensions, at least 2");
355-
}
356-
357-
if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
358-
fpType == ScaleDotElemType::E5M2)) {
359-
return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
360-
}
361-
362-
auto layoutX = xTy.getEncoding();
363-
auto layoutScale = scaleTy.getEncoding();
364-
if (bool(layoutX) != bool(layoutScale)) {
365-
return emitOpError(
366-
"Expected either both or neither operands to have an encoding");
367-
}
368-
// Nothing to check if no encoding. This is used to infer the return type in
369-
// AccelerateMatmul.cpp
370-
if (!layoutX) {
371-
return success();
372-
}
373-
374-
auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
375-
if (!dotEncoding) {
376-
return emitOpError("Expected a DotOperandEncodingAttr for values");
377-
}
378-
if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
379-
return emitOpError(
380-
"Expected a BlockOperandEncoding or LinearOperandEncoding "
381-
"for scales");
382-
}
383-
384-
if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
385-
// Necessary to keep all of the scales of a given block of values in the
386-
// same warp
387-
auto threadsPerWarp =
388-
cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
389-
if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
390-
return emitOpError("Expected threads per warp to be {16, 2}");
332+
LogicalResult Fp4ToFpOp::verify() {
333+
auto srcTy = cast<RankedTensorType>(getSrc().getType());
334+
auto resTy = cast<RankedTensorType>(getResult().getType());
335+
auto rank = srcTy.getRank();
336+
337+
if (rank != resTy.getRank())
338+
return emitError() << "source rank " << rank << " != result rank "
339+
<< resTy.getRank();
340+
341+
auto srcShape = srcTy.getShape();
342+
auto resShape = resTy.getShape();
343+
auto axis = getAxis();
344+
345+
if (!(0 <= axis && axis < rank))
346+
return emitError() << "axis " << axis << " out of range for rank " << rank;
347+
348+
auto elemType = resTy.getElementType();
349+
if (!(elemType.isBF16() || elemType.isF16()))
350+
return emitError() << "only bf16 or f16 is supported for now, got "
351+
<< elemType;
352+
353+
for (int i = 0; i < rank; ++i) {
354+
if (i == axis) {
355+
if (resShape[i] != srcShape[i] * 2)
356+
return emitError() << "axis " << axis
357+
<< " dimension must be 2x source dimension (src="
358+
<< srcShape[i] << ", dst=" << resShape[i] << ")";
359+
} else {
360+
if (resShape[i] != srcShape[i])
361+
return emitError() << "dimension " << i
362+
<< " mismatch (src=" << srcShape[i]
363+
<< ", dst=" << resShape[i] << ", axis=" << axis
364+
<< ")";
391365
}
392366
}
393-
394-
// Change to support fp8 types
395-
const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
396-
// Figure out the K dimension for the input A/B. For A/B scale, the K
397-
// dimension is always the last dimension.
398-
const int opIdx = dotEncoding.getOpIdx();
399-
const bool hasBatch = xShape.size() == 3;
400-
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
401-
402-
if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
403-
return emitOpError("K dimension of first operand must be 16 times "
404-
"larger than last/K dimension of the second operand");
405-
}
406-
407-
// Check other dimensions match too. For input A/B, we need to figure out the
408-
// index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
409-
const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
410-
if (hasBatch && xShape[0] != scaleShape[0])
411-
return emitOpError("batch dimension must match between operands");
412-
if (xShape[mnIdx] != scaleShape[hasBatch]) {
413-
return emitOpError("M/N dimension must match between operands");
414-
}
415-
416367
return success();
417368
}
418369

419-
RankedTensorType
420-
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
421-
ScaleDotElemType inputElemType,
422-
Type outputElemType) {
423-
MLIRContext *ctx = inputTensor.getContext();
424-
auto xTy = inputTensor.getType();
425-
if (inputElemType != ScaleDotElemType::E2M1)
426-
return xTy;
427-
428-
auto xShape = xTy.getShape();
429-
auto newShape = llvm::to_vector(xShape);
430-
auto encoding = xTy.getEncoding();
431-
if (!encoding) {
432-
newShape.back() *= 2;
433-
return RankedTensorType::get(xShape, outputElemType);
434-
}
435-
436-
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
437-
auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
438-
oldEncoding.getParent(),
439-
oldEncoding.getKWidth() * 2);
440-
// Figure out the K dimension for the input A/B, given that the return
441-
// type is upcasted A/B type so we need to update the proper dim size.
442-
const int opIdx = oldEncoding.getOpIdx();
443-
const bool hasBatch = xShape.size() == 3;
444-
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
445-
newShape[kIdx] *= 2;
446-
return RankedTensorType::get(newShape, outputElemType, newVEncoding);
370+
void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state,
371+
TypedValue<RankedTensorType> src, Type elemType,
372+
int32_t axis) {
373+
auto srcTy = src.getType();
374+
auto shape = llvm::to_vector(srcTy.getShape());
375+
auto rank = srcTy.getRank();
376+
assert(0 <= axis && axis < rank);
377+
shape[axis] *= 2;
378+
379+
Attribute inEnc = srcTy.getEncoding();
380+
Attribute outEnc;
381+
auto result =
382+
inEnc.getDialect()
383+
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
384+
->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc,
385+
/*fwdInference=*/true, state.location);
386+
assert(succeeded(result));
387+
388+
auto resultTy = RankedTensorType::get(shape, elemType, outEnc);
389+
build(builder, state, resultTy, src, axis);
447390
}
448391

449392
OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {

0 commit comments

Comments
 (0)