
Commit 5eda5e3

Revert "[MXFP] Implement SW emulation of dot_scale as a decomposition (#5475)"
This reverts commit 929142b.
1 parent 250c92d commit 5eda5e3

File tree: 29 files changed, +1356 -1252 lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 0 additions & 5 deletions
@@ -83,11 +83,6 @@ class DialectInferLayoutInterface
   virtual LogicalResult
   verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
                                    Attribute operandEncodingB) const = 0;
-
-  virtual LogicalResult
-  inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
-                         Attribute &outEnc, bool fwdInference,
-                         std::optional<Location> loc) const = 0;
 };
 
 class DialectVerifyTensorLayoutInterface

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 0 additions & 4 deletions
@@ -456,10 +456,6 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
     If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
     The compiler is still free to change it for better performance.
   }];
-  let builders = [
-      OpBuilder<(ins "ArrayRef<int64_t>":$shape, "TypedValue<RankedTensorType>":$src)>
-  ];
-
   let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
   let results = (outs TT_Tensor:$result);
   let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 18 additions & 16 deletions
@@ -283,29 +283,31 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
   }];
 }
 
-def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
-  let summary = "Upcast fp4 (e2m1) to fp";
+def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
+  let summary = "Convert an mxfp tensor to bf16/fp16";
 
   let hasVerifier = 1;
 
   let description = [{
-    Upcast fp4 (e2m1) represented packed as i8s to fp.
-
-    The lower 4 bits of the i8s represent the first fp4 element, and the upper 4 bits
-    the second fp4 element.
-
-    The `axis` attribute specifies the axis along which the fp4 elements are packed.
+    Compute the bf16 encoded in the given mxfp number as per
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
   }];
-
-  let builders = [
-    OpBuilder<(ins "TypedValue<RankedTensorType>":$src, "Type":$elemType, "int32_t":$axis)>
-  ];
-
-  let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis);
-  let results = (outs TT_FloatTensor:$result);
+  let arguments = (
+    ins
+    TT_Tensor:$src,
+    TT_Tensor:$scale,
+    TT_ScaleDotElemTypeAttr:$fp_type,
+    BoolAttr:$fastMath
+  );
+  let results = (outs TT_Tensor:$result);
 
   let assemblyFormat = [{
-    $src attr-dict `:` type($src) `->` type($result)
+    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    static RankedTensorType deduceOutputType(
+        TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
   }];
 }
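
The op description above only points at the OCP MX spec, so a brief worked decode may help: each packed i8 holds two fp4 (e2m1) values, and each scale byte is interpreted as an e8m0 power-of-two exponent (bias 127) shared by a block of values. Below is a minimal standalone C++ sketch of that arithmetic; the helper names and sample values are illustrative only and are not Triton code.

// Hypothetical sketch of the mxfp decode referenced by the op description:
// unpack two e2m1 values from one byte and apply an e8m0 shared scale.
#include <cmath>
#include <cstdint>
#include <cstdio>

// The eight non-negative e2m1 magnitudes; bit 3 of the nibble is the sign.
static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

float decodeE2M1(std::uint8_t nibble) {
  float mag = kE2M1[nibble & 0x7];
  return (nibble & 0x8) ? -mag : mag;
}

// e8m0: an exponent-only scale, value = 2^(byte - 127).
float decodeE8M0(std::uint8_t scale) {
  return std::ldexp(1.0f, int(scale) - 127);
}

int main() {
  std::uint8_t packed = 0x26; // low nibble 0x6 -> 4.0, high nibble 0x2 -> 1.0
  std::uint8_t scale = 126;   // 2^-1
  float first = decodeE2M1(packed & 0xF) * decodeE8M0(scale);
  float second = decodeE2M1(packed >> 4) * decodeE8M0(scale);
  std::printf("%g %g\n", first, second); // prints: 2 0.5
  return 0;
}
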
include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 0 additions & 21 deletions
@@ -693,23 +693,6 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
 }
 
 //-- ReshapeOp --
-
-void ReshapeOp::build(OpBuilder &builder, OperationState &state,
-                      ArrayRef<int64_t> shape,
-                      TypedValue<RankedTensorType> src) {
-  auto srcTy = src.getType();
-  auto srcEnc = srcTy.getEncoding();
-  Attribute dstEnc;
-  if (srcEnc) {
-    auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
-                      ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
-                                               dstEnc, state.location);
-    assert(succeeded(result));
-  }
-  auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
-  build(builder, state, dstTy, src);
-}
-
 LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
   if (op.getEfficientLayout())
     return failure();
@@ -786,10 +769,6 @@ LogicalResult ReshapeOp::verify() {
 OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
   auto srcVal = getSrc();
   auto dstTy = getType();
-  // Fold trivial cast
-  if (srcVal.getType() == dstTy) {
-    return srcVal;
-  }
 
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &semantic = resElemType.getFloatSemantics();

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 92 deletions
@@ -2882,98 +2882,6 @@ struct TritonGPUInferLayoutInterface
                               ArrayRef(enc.getCTAOrder()).drop_front(1)));
     return success();
   }
-
-  LogicalResult
-  inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
-                         Attribute &outEnc, bool fwdInference,
-                         std::optional<Location> loc) const override {
-    // We implement two legacy layout propagations
-    // Once we fully migrate to LinearLayouts, we can remove these.
-    auto *ctx = getContext();
-    auto rank = shape.size();
-    // The output encoding will only be a legacy encoding if the axis is the
-    // fastest running dimension.
-    if (getOrder(inEnc)[axis] == 0) {
-      // Dot operand: double kWidth if kDim == axis.
-      if (auto dotEnc = mlir::dyn_cast<DotOperandEncodingAttr>(inEnc)) {
-        auto kWidth = dotEnc.getKWidth();
-        if (fwdInference) {
-          kWidth *= 2;
-        } else {
-          if (kWidth > 1) {
-            // bwd inference
-            kWidth /= 2;
-          } else {
-            return emitOptionalError(loc,
-                                     "Fp4ToFpOp requires at least 2 elements "
-                                     "per thread in the axis dimension");
-          }
-        }
-        outEnc = DotOperandEncodingAttr::get(ctx, dotEnc.getOpIdx(),
-                                             dotEnc.getParent(), kWidth);
-        return success();
-      }
-
-      // Blocked layout: double elemsPerThread[axis].
-      if (auto blockedEnc = mlir::dyn_cast<BlockedEncodingAttr>(inEnc)) {
-        auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread());
-        if (fwdInference) {
-          sizePerThread[axis] *= 2;
-        } else {
-          if (sizePerThread[axis] > 1) {
-            sizePerThread[axis] /= 2;
-          } else {
-            return emitOptionalError(
-                loc, "Fp4ToFpOp requires at least 2 elements per "
-                     "thread in the axis dimension");
-          }
-        }
-        outEnc = BlockedEncodingAttr::get(
-            ctx, sizePerThread, blockedEnc.getThreadsPerWarp(),
-            blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(),
-            blockedEnc.getCTALayout());
-        return success();
-      }
-    }
-
-    auto ll = toLinearLayout(shape, inEnc);
-
-    auto kRegister = StringAttr::get(ctx, "register");
-    auto outDims = llvm::to_vector(ll.getOutDimNames());
-    LinearLayout newLl = LinearLayout::empty();
-    if (fwdInference) {
-      auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
-      newLl = split * ll;
-      // FIXME!!!!
-      // operator* transposes the output dimensions??!! WTF
-      newLl = newLl.transposeOuts(outDims);
-    } else {
-      // TODO This requires a division algorithm!
-      // Implement manually ll.divideLeft(split)
-      auto contiguousElems =
-          LinearEncodingAttr::get(ctx, ll).getContigPerThread();
-      if (contiguousElems[axis] > 1) {
-        LinearLayout::BasesT newBases;
-        for (const auto &basesDim : ll.getBases()) {
-          std::vector<std::vector<int32_t>> newBasesDim;
-          for (auto base : basesDim.second) {
-            if (base[axis] == 1) {
-              continue;
-            }
-            base[axis] /= 2;
-            newBasesDim.push_back(std::move(base));
-          }
-          newBases.insert({basesDim.first, std::move(newBasesDim)});
-        }
-        newLl = LinearLayout(std::move(newBases), std::move(outDims));
-      } else {
-        return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
-                                      "per thread in the axis dimension");
-      }
-    }
-    outEnc = LinearEncodingAttr::get(ctx, newLl);
-    return success();
-  }
 };
 
 struct TritonGPUVerifyTensorLayoutInterface
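
The legacy-layout branch removed above reduces to a simple doubling/halving rule along the packed axis: the upcast produces two fp values per packed i8, so forward inference doubles the per-thread element count (kWidth or sizePerThread[axis]) and backward inference halves it. A standalone C++ sketch of just that rule, with illustrative names that are not Triton code:

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

std::optional<std::vector<std::int64_t>>
inferSizePerThread(std::vector<std::int64_t> sizePerThread, int axis,
                   bool fwdInference) {
  if (fwdInference) {
    sizePerThread[axis] *= 2; // one i8 unpacks into two fp elements
  } else if (sizePerThread[axis] > 1) {
    sizePerThread[axis] /= 2; // undo the doubling when inferring backwards
  } else {
    return std::nullopt; // needs at least 2 elements per thread on `axis`
  }
  return sizePerThread;
}

int main() {
  auto fwd = inferSizePerThread({1, 8}, /*axis=*/1, /*fwdInference=*/true);
  assert(fwd && (*fwd)[1] == 16);
  auto bwd = inferSizePerThread({1, 1}, /*axis=*/1, /*fwdInference=*/false);
  assert(!bwd);
  return 0;
}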

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 110 additions & 53 deletions
@@ -331,64 +331,121 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CanonicalizeConvertFromSplit>(context);
 }
 
-LogicalResult Fp4ToFpOp::verify() {
-  auto srcTy = cast<RankedTensorType>(getSrc().getType());
-  auto resTy = cast<RankedTensorType>(getResult().getType());
-  auto rank = srcTy.getRank();
-
-  if (rank != resTy.getRank())
-    return emitError() << "source rank " << rank << " != result rank "
-                       << resTy.getRank();
-
-  auto srcShape = srcTy.getShape();
-  auto resShape = resTy.getShape();
-  auto axis = getAxis();
-
-  if (!(0 <= axis && axis < rank))
-    return emitError() << "axis " << axis << " out of range for rank " << rank;
-
-  auto elemType = resTy.getElementType();
-  if (!(elemType.isBF16() || elemType.isF16()))
-    return emitError() << "only bf16 or f16 is supported for now, got "
-                       << elemType;
-
-  for (int i = 0; i < rank; ++i) {
-    if (i == axis) {
-      if (resShape[i] != srcShape[i] * 2)
-        return emitError() << "axis " << axis
-                           << " dimension must be 2x source dimension (src="
-                           << srcShape[i] << ", dst=" << resShape[i] << ")";
-    } else {
-      if (resShape[i] != srcShape[i])
-        return emitError() << "dimension " << i
-                           << " mismatch (src=" << srcShape[i]
-                           << ", dst=" << resShape[i] << ", axis=" << axis
-                           << ")";
+LogicalResult UpcastMXFPOp::verify() {
+  auto fpType = getFpType();
+
+  auto xTy = getSrc().getType();
+  auto scaleTy = getScale().getType();
+  Builder b(getContext());
+  if (xTy.getElementType() != b.getBF16Type() &&
+      xTy.getElementType() != b.getF16Type() &&
+      xTy.getElementType() != b.getI8Type()) {
+    return emitOpError(
+        "element type of the first operand must be bf16/fp16 or i8");
+  }
+
+  if (scaleTy.getElementType() != b.getI8Type()) {
+    return emitOpError("element type of the second operand must be uint8");
+  }
+
+  auto xShape = xTy.getShape();
+  auto scaleShape = scaleTy.getShape();
+
+  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
+    return emitOpError(
+        "operands must have the same number of dimensions, at least 2");
+  }
+
+  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
+        fpType == ScaleDotElemType::E5M2)) {
+    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
+  }
+
+  auto layoutX = xTy.getEncoding();
+  auto layoutScale = scaleTy.getEncoding();
+  if (bool(layoutX) != bool(layoutScale)) {
+    return emitOpError(
+        "Expected either both or neither operands to have an encoding");
+  }
+  // Nothing to check if no encoding. This is used to infer the return type in
+  // AccelerateMatmul.cpp
+  if (!layoutX) {
+    return success();
+  }
+
+  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
+  if (!dotEncoding) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
+  if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
+    return emitOpError(
+        "Expected a BlockOperandEncoding or LinearOperandEncoding "
+        "for scales");
+  }
+
+  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+    // Necessary to keep all of the scales of a given block of values in the
+    // same warp
+    auto threadsPerWarp =
+        cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
+    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+      return emitOpError("Expected threads per warp to be {16, 2}");
     }
   }
+
+  // Change to support fp8 types
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be 16 times "
+                       "larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
+
   return success();
 }
 
-void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state,
-                      TypedValue<RankedTensorType> src, Type elemType,
-                      int32_t axis) {
-  auto srcTy = src.getType();
-  auto shape = llvm::to_vector(srcTy.getShape());
-  auto rank = srcTy.getRank();
-  assert(0 <= axis && axis < rank);
-  shape[axis] *= 2;
-
-  Attribute inEnc = srcTy.getEncoding();
-  Attribute outEnc;
-  auto result =
-      inEnc.getDialect()
-          .getRegisteredInterface<triton::DialectInferLayoutInterface>()
-          ->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc,
-                                   /*fwdInference=*/true, state.location);
-  assert(succeeded(result));
-
-  auto resultTy = RankedTensorType::get(shape, elemType, outEnc);
-  build(builder, state, resultTy, src, axis);
+RankedTensorType
+UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
+                               ScaleDotElemType inputElemType,
+                               Type outputElemType) {
+  MLIRContext *ctx = inputTensor.getContext();
+  auto xTy = inputTensor.getType();
+  if (inputElemType != ScaleDotElemType::E2M1)
+    return xTy;
+
+  auto xShape = xTy.getShape();
+  auto newShape = llvm::to_vector(xShape);
+  auto encoding = xTy.getEncoding();
+  if (!encoding) {
+    newShape.back() *= 2;
+    return RankedTensorType::get(xShape, outputElemType);
+  }
+
+  auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+  auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
+                                                  oldEncoding.getParent(),
+                                                  oldEncoding.getKWidth() * 2);
+  // Figure out the K dimension for the input A/B, given that the return
+  // type is upcasted A/B type so we need to update the proper dim size.
+  const int opIdx = oldEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+  newShape[kIdx] *= 2;
+  return RankedTensorType::get(newShape, outputElemType, newVEncoding);
 }
 
 OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
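
As a worked example of the shape checks above, assuming opIdx == 0 (operand A) and no batch dimension: per the MX formats, one e8m0 scale covers a block of 32 values, and e2m1 packs two values per i8, so the stored K dimension of the values must equal (32 / elemsPacked) times the last dimension of the scale tensor, and upcasting e2m1 then doubles K again. A standalone C++ sketch of just that arithmetic (not Triton code; the shapes are made up):

#include <cassert>
#include <cstdint>

int main() {
  const std::int64_t blockSize = 32; // values covered by one e8m0 scale
  const int elemsPacked = 2;         // 2 for e2m1, 1 for e4m3/e5m2

  std::int64_t scaleShape[2] = {128, 4}; // {M, K / 32}
  std::int64_t xShape[2] = {128, 64};    // {M, K / elemsPacked}

  // verify(): packed K must be (32 / elemsPacked) * scale K.
  assert(xShape[1] == (blockSize / elemsPacked) * scaleShape[1]);

  // deduceOutputType(): upcasting e2m1 doubles the K dimension,
  // recovering K = 32 * scale K unpacked values.
  std::int64_t upcastK = xShape[1] * 2;
  assert(upcastK == blockSize * scaleShape[1]);
  return 0;
}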

0 commit comments
