Skip to content

Commit 250c92d

Browse files
Merge commit '929142bb84e661c100a80cd2d924b5847dff98ed'
2 parents fff08ef + 929142b commit 250c92d

File tree

29 files changed

+1252
-1356
lines changed

29 files changed

+1252
-1356
lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ class DialectInferLayoutInterface
8383
  // Verify that the operand encodings of a dot operation are compatible
  // with each other (implemented per-dialect).
  virtual LogicalResult
  verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
                                   Attribute operandEncodingB) const = 0;

  // Infer the encoding for the result of a Fp4ToFpOp (which unpacks two fp4
  // values per i8 along `axis`, doubling that dimension) given the input
  // encoding `inEnc`, writing the result into `outEnc`.
  // If `fwdInference` is true, infer the output encoding from the input;
  // otherwise infer the input encoding from the output. Emits a diagnostic
  // at `loc` (when provided) on failure.
  virtual LogicalResult
  inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
                         Attribute &outEnc, bool fwdInference,
                         std::optional<Location> loc) const = 0;
};
8792

8893
class DialectVerifyTensorLayoutInterface

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,10 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
456456
If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
457457
The compiler is still free to change it for better performance.
458458
}];
459+
let builders = [
460+
OpBuilder<(ins "ArrayRef<int64_t>":$shape, "TypedValue<RankedTensorType>":$src)>
461+
];
462+
459463
let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
460464
let results = (outs TT_Tensor:$result);
461465
let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -283,31 +283,29 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
283283
}];
284284
}
285285

286-
def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
287-
let summary = "Convert an mxfp tensor to bf16/fp16";
286+
// Upcasts a tensor of packed fp4 (e2m1) values, stored two-per-i8, to a
// floating-point tensor. The result doubles the size of `axis`.
def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
  let summary = "Upcast fp4 (e2m1) to fp";

  let hasVerifier = 1;

  let description = [{
    Upcast fp4 (e2m1) represented packed as i8s to fp.

    The lower 4 bits of the i8s represent the first fp4 element, and the upper 4 bits
    the second fp4 element.

    The `axis` attribute specifies the axis along which the fp4 elements are packed.
  }];

  // Convenience builder that infers the result type (shape and encoding)
  // from the source tensor, the element type, and the packing axis.
  let builders = [
    OpBuilder<(ins "TypedValue<RankedTensorType>":$src, "Type":$elemType, "int32_t":$axis)>
  ];

  let arguments = (ins RankedTensorOf<[I8]>:$src, I32Attr:$axis);
  let results = (outs TT_FloatTensor:$result);

  let assemblyFormat = [{
    $src attr-dict `:` type($src) `->` type($result)
  }];
}
313311

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Header guard was missing: without it, double inclusion would redeclare the
// function (harmless) but violates project header conventions.
#pragma once

#include "mlir/IR/PatternMatch.h"

namespace mlir::triton::gpu {

/// Populates `patterns` with the rewrite patterns that decompose scaled
/// blocked dot operations, registered with the given pattern `benefit`.
void populateDecomposeScaledBlockedPatterns(mlir::RewritePatternSet &patterns,
                                            int benefit);

} // namespace mlir::triton::gpu

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,23 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
693693
}
694694

695695
//-- ReshapeOp --
696+
697+
/// Builds a ReshapeOp whose result type is inferred from the source tensor:
/// the element type is carried over, and — when the source carries an
/// encoding — the result encoding is inferred via the source dialect's
/// DialectInferLayoutInterface.
void ReshapeOp::build(OpBuilder &builder, OperationState &state,
                      ArrayRef<int64_t> shape,
                      TypedValue<RankedTensorType> src) {
  auto srcTy = src.getType();
  auto srcEnc = srcTy.getEncoding();
  Attribute dstEnc;
  if (srcEnc) {
    auto result = cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
                      ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, shape,
                                               dstEnc, state.location);
    assert(succeeded(result));
    // Silence -Wunused-variable in NDEBUG builds where the assert vanishes.
    (void)result;
  }
  // dstEnc stays null when the source had no encoding.
  auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), dstEnc);
  build(builder, state, dstTy, src);
}
712+
696713
LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
697714
if (op.getEfficientLayout())
698715
return failure();
@@ -769,6 +786,10 @@ LogicalResult ReshapeOp::verify() {
769786
OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
770787
auto srcVal = getSrc();
771788
auto dstTy = getType();
789+
// Fold trivial cast
790+
if (srcVal.getType() == dstTy) {
791+
return srcVal;
792+
}
772793

773794
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
774795
const llvm::fltSemantics &semantic = resElemType.getFloatSemantics();

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2882,6 +2882,98 @@ struct TritonGPUInferLayoutInterface
28822882
ArrayRef(enc.getCTAOrder()).drop_front(1)));
28832883
return success();
28842884
}
2885+
2886+
  // Infers the layout across a Fp4ToFpOp, which doubles the number of
  // elements along `axis`. Forward inference (fwdInference=true) maps the
  // packed-i8 input encoding to the unpacked result; backward inference
  // halves instead, and fails if there are fewer than 2 elements per thread
  // along `axis`.
  LogicalResult
  inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
                         Attribute &outEnc, bool fwdInference,
                         std::optional<Location> loc) const override {
    // We implement two legacy layout propagations (dot-operand and blocked).
    // Once we fully migrate to LinearLayouts, we can remove these.
    auto *ctx = getContext();
    // NOTE(review): `rank` appears unused below — confirm whether it can be
    // removed (it may trigger -Wunused-variable).
    auto rank = shape.size();
    // The output encoding will only be a legacy encoding if the axis is the
    // fastest running dimension.
    if (getOrder(inEnc)[axis] == 0) {
      // Dot operand: double kWidth if kDim == axis.
      if (auto dotEnc = mlir::dyn_cast<DotOperandEncodingAttr>(inEnc)) {
        auto kWidth = dotEnc.getKWidth();
        if (fwdInference) {
          kWidth *= 2;
        } else {
          if (kWidth > 1) {
            // bwd inference: halve the elements per thread along K.
            kWidth /= 2;
          } else {
            return emitOptionalError(loc,
                                     "Fp4ToFpOp requires at least 2 elements "
                                     "per thread in the axis dimension");
          }
        }
        outEnc = DotOperandEncodingAttr::get(ctx, dotEnc.getOpIdx(),
                                             dotEnc.getParent(), kWidth);
        return success();
      }

      // Blocked layout: double elemsPerThread[axis].
      if (auto blockedEnc = mlir::dyn_cast<BlockedEncodingAttr>(inEnc)) {
        auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread());
        if (fwdInference) {
          sizePerThread[axis] *= 2;
        } else {
          if (sizePerThread[axis] > 1) {
            sizePerThread[axis] /= 2;
          } else {
            return emitOptionalError(
                loc, "Fp4ToFpOp requires at least 2 elements per "
                     "thread in the axis dimension");
          }
        }
        outEnc = BlockedEncodingAttr::get(
            ctx, sizePerThread, blockedEnc.getThreadsPerWarp(),
            blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(),
            blockedEnc.getCTALayout());
        return success();
      }
    }

    // General case: go through LinearLayout.
    auto ll = toLinearLayout(shape, inEnc);

    auto kRegister = StringAttr::get(ctx, "register");
    auto outDims = llvm::to_vector(ll.getOutDimNames());
    LinearLayout newLl = LinearLayout::empty();
    if (fwdInference) {
      // Prepend an identity register dimension of size 2 along `axis` so
      // each thread holds the two unpacked values contiguously.
      auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
      newLl = split * ll;
      // FIXME: operator* reorders the output dimensions; transpose them back
      // to the original order.
      newLl = newLl.transposeOuts(outDims);
    } else {
      // TODO: This requires a division algorithm!
      // Implement manually ll.divideLeft(split).
      auto contiguousElems =
          LinearEncodingAttr::get(ctx, ll).getContigPerThread();
      if (contiguousElems[axis] > 1) {
        // Drop the unit basis along `axis` and halve the remaining bases,
        // effectively dividing the layout by identity1D(2, register, axis).
        LinearLayout::BasesT newBases;
        for (const auto &basesDim : ll.getBases()) {
          std::vector<std::vector<int32_t>> newBasesDim;
          for (auto base : basesDim.second) {
            if (base[axis] == 1) {
              continue;
            }
            base[axis] /= 2;
            newBasesDim.push_back(std::move(base));
          }
          newBases.insert({basesDim.first, std::move(newBasesDim)});
        }
        newLl = LinearLayout(std::move(newBases), std::move(outDims));
      } else {
        return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
                                      "per thread in the axis dimension");
      }
    }
    outEnc = LinearEncodingAttr::get(ctx, newLl);
    return success();
  }
28852977
};
28862978

28872979
struct TritonGPUVerifyTensorLayoutInterface

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 53 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -331,121 +331,64 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
331331
patterns.add<CanonicalizeConvertFromSplit>(context);
332332
}
333333

334-
LogicalResult UpcastMXFPOp::verify() {
335-
auto fpType = getFpType();
336-
337-
auto xTy = getSrc().getType();
338-
auto scaleTy = getScale().getType();
339-
Builder b(getContext());
340-
if (xTy.getElementType() != b.getBF16Type() &&
341-
xTy.getElementType() != b.getF16Type() &&
342-
xTy.getElementType() != b.getI8Type()) {
343-
return emitOpError(
344-
"element type of the first operand must be bf16/fp16 or i8");
345-
}
346-
347-
if (scaleTy.getElementType() != b.getI8Type()) {
348-
return emitOpError("element type of the second operand must be uint8");
349-
}
350-
351-
auto xShape = xTy.getShape();
352-
auto scaleShape = scaleTy.getShape();
353-
354-
if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
355-
return emitOpError(
356-
"operands must have the same number of dimensions, at least 2");
357-
}
358-
359-
if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
360-
fpType == ScaleDotElemType::E5M2)) {
361-
return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
362-
}
363-
364-
auto layoutX = xTy.getEncoding();
365-
auto layoutScale = scaleTy.getEncoding();
366-
if (bool(layoutX) != bool(layoutScale)) {
367-
return emitOpError(
368-
"Expected either both or neither operands to have an encoding");
369-
}
370-
// Nothing to check if no encoding. This is used to infer the return type in
371-
// AccelerateMatmul.cpp
372-
if (!layoutX) {
373-
return success();
374-
}
375-
376-
auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
377-
if (!dotEncoding) {
378-
return emitOpError("Expected a DotOperandEncodingAttr for values");
379-
}
380-
if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
381-
return emitOpError(
382-
"Expected a BlockOperandEncoding or LinearOperandEncoding "
383-
"for scales");
384-
}
385-
386-
if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
387-
// Necessary to keep all of the scales of a given block of values in the
388-
// same warp
389-
auto threadsPerWarp =
390-
cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
391-
if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
392-
return emitOpError("Expected threads per warp to be {16, 2}");
334+
// Verifies the shape/type contract of Fp4ToFpOp: same rank, bf16/f16 result
// element type, result dimensions equal to the source's except along `axis`,
// which must be exactly doubled (each i8 packs two fp4 values).
LogicalResult Fp4ToFpOp::verify() {
  auto srcTy = cast<RankedTensorType>(getSrc().getType());
  auto resTy = cast<RankedTensorType>(getResult().getType());
  auto rank = srcTy.getRank();

  if (rank != resTy.getRank())
    return emitError() << "source rank " << rank << " != result rank "
                       << resTy.getRank();

  auto srcShape = srcTy.getShape();
  auto resShape = resTy.getShape();
  auto axis = getAxis();

  if (!(0 <= axis && axis < rank))
    return emitError() << "axis " << axis << " out of range for rank " << rank;

  // Only bf16/f16 destinations are implemented so far.
  auto elemType = resTy.getElementType();
  if (!(elemType.isBF16() || elemType.isF16()))
    return emitError() << "only bf16 or f16 is supported for now, got "
                       << elemType;

  for (int i = 0; i < rank; ++i) {
    if (i == axis) {
      // The packed axis unpacks two fp4 values per byte.
      if (resShape[i] != srcShape[i] * 2)
        return emitError() << "axis " << axis
                           << " dimension must be 2x source dimension (src="
                           << srcShape[i] << ", dst=" << resShape[i] << ")";
    } else {
      if (resShape[i] != srcShape[i])
        return emitError() << "dimension " << i
                           << " mismatch (src=" << srcShape[i]
                           << ", dst=" << resShape[i] << ", axis=" << axis
                           << ")";
    }
  }
  return success();
}
420371

421-
RankedTensorType
422-
UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
423-
ScaleDotElemType inputElemType,
424-
Type outputElemType) {
425-
MLIRContext *ctx = inputTensor.getContext();
426-
auto xTy = inputTensor.getType();
427-
if (inputElemType != ScaleDotElemType::E2M1)
428-
return xTy;
429-
430-
auto xShape = xTy.getShape();
431-
auto newShape = llvm::to_vector(xShape);
432-
auto encoding = xTy.getEncoding();
433-
if (!encoding) {
434-
newShape.back() *= 2;
435-
return RankedTensorType::get(xShape, outputElemType);
436-
}
437-
438-
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
439-
auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
440-
oldEncoding.getParent(),
441-
oldEncoding.getKWidth() * 2);
442-
// Figure out the K dimension for the input A/B, given that the return
443-
// type is upcasted A/B type so we need to update the proper dim size.
444-
const int opIdx = oldEncoding.getOpIdx();
445-
const bool hasBatch = xShape.size() == 3;
446-
const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
447-
newShape[kIdx] *= 2;
448-
return RankedTensorType::get(newShape, outputElemType, newVEncoding);
372+
/// Builds a Fp4ToFpOp whose result type is inferred from the source: the
/// shape is the source shape with dimension `axis` doubled, the element type
/// is `elemType`, and the encoding is forward-inferred through the source
/// dialect's DialectInferLayoutInterface.
void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state,
                      TypedValue<RankedTensorType> src, Type elemType,
                      int32_t axis) {
  auto srcTy = src.getType();
  auto shape = llvm::to_vector(srcTy.getShape());
  auto rank = srcTy.getRank();
  assert(0 <= axis && axis < rank);
  // Silence -Wunused-variable in NDEBUG builds where the assert vanishes.
  (void)rank;
  shape[axis] *= 2;

  // NOTE(review): this assumes `src` carries an encoding; `inEnc.getDialect()`
  // would dereference a null attribute otherwise — confirm callers guarantee
  // an encoded tensor (cf. ReshapeOp::build, which guards on the encoding).
  Attribute inEnc = srcTy.getEncoding();
  Attribute outEnc;
  auto result =
      inEnc.getDialect()
          .getRegisteredInterface<triton::DialectInferLayoutInterface>()
          ->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc,
                                   /*fwdInference=*/true, state.location);
  assert(succeeded(result));
  // Silence -Wunused-variable in NDEBUG builds where the assert vanishes.
  (void)result;

  auto resultTy = RankedTensorType::get(shape, elemType, outEnc);
  build(builder, state, resultTy, src, axis);
}
450393

451394
OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {

0 commit comments

Comments
 (0)