Commit 4daa467

Implement scaled_dot(mxfp8, fp8) via mma (#4795)
Initial implementation using mma. Still missing a test that it plays well with the pipeliner.
1 parent d39ee1f commit 4daa467

File tree

22 files changed: +719 −25 lines changed


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 4 additions & 2 deletions
@@ -125,16 +125,18 @@ using namespace mlir::triton;
 #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

 // Constants
+#define int_val(bitwidth, val) \
+  LLVM::createLLVMIntegerConstant(rewriter, loc, bitwidth, val)
 #define i1_val(val) LLVM::createConstantI1(loc, rewriter, val)
 #define true_val() i1_val(true)
 #define false_val() i1_val(false)
 #define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__)
 #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
 #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
+#define i8_val(val) int_val(8, val)
+#define i16_val(val) int_val(16, val)
 #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
 #define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__)
-#define int_val(width, val) \
-  LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
 #define tid_val() getThreadId(rewriter, loc)

 // Attributes
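The new i8_val and i16_val helpers are thin wrappers over int_val, which moves above them so the definition precedes its users when reading the file. A stubbed, self-contained sketch of how the wrappers expand (the real macros call into MLIR's rewriter and need rewriter and loc in scope; the stub below is hypothetical):

// Self-contained illustration of the macro layering above. The stub stands in
// for LLVM::createLLVMIntegerConstant, which the real macros invoke.
#include <cstdio>

static int createIntegerConstant(int bitwidth, long val) {
  std::printf("i%d constant: %ld\n", bitwidth, val);
  return bitwidth;
}

#define int_val(bitwidth, val) createIntegerConstant(bitwidth, val)
#define i8_val(val) int_val(8, val)
#define i16_val(val) int_val(16, val)

int main() {
  i8_val(42);  // expands to createIntegerConstant(8, 42)
  i16_val(-1); // expands to createIntegerConstant(16, -1)
}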

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 6 additions & 4 deletions
@@ -81,10 +81,12 @@ class DotLike : public TraitBase<ConcreteType, DotLike> {
     if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
       return op->emitOpError("expected all operands to have the same rank");
     // Check if the first two operands share a common dimension
-    if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
-      return op->emitOpError("expected the last dimension of the first operand "
-                             "to be equal to the second-to-last dimension of "
-                             "the second operand");
+    // TODO: enable back with an interface to support scaled dot.
+    // if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
+    //   return op->emitOpError("expected the last dimension of the first operand "
+    //                          "to be equal to the second-to-last dimension of "
+    //                          "the second operand");
     // Check the batch dimension
     if (aShape.size() == 3 &&
         (aShape[0] != cShape[0] || bShape[0] != cShape[0]))
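The disabled check no longer holds for scaled dot because the lhs can be packed: with e2m1, two fp4 values share one i8, so the lhs's last dimension is K/2 while the rhs's second-to-last dimension stays K. A sketch of the shape check a scaled-dot-aware interface could restore (packFactor and kDimsCompatible are illustrative names, not part of the commit):

// Hypothetical K-dimension check accounting for packed scaled-dot operands.
// packFactor would come from the future interface: 2 for e2m1 packed two per
// i8, 1 for the fp8 cases this commit targets.
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>

bool kDimsCompatible(llvm::ArrayRef<int64_t> aShape,
                     llvm::ArrayRef<int64_t> bShape, int64_t packFactor) {
  return aShape[aShape.size() - 1] * packFactor == bShape[bShape.size() - 2];
}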

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 14 additions & 0 deletions
@@ -119,4 +119,18 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   let cppNamespace = "::mlir::triton";
 }

+// Type for F8F6F4 kind of floats.
+def TT_F8F6F4TypeAttr : I32EnumAttr<
+  "F8F6F4Type", "",
+  [
+    I32EnumAttrCase<"E4M3", 0, "e4m3">,
+    I32EnumAttrCase<"E5M2", 1, "e5m2">,
+    I32EnumAttrCase<"E2M3", 2, "e2m3">,
+    I32EnumAttrCase<"E3M2", 3, "e3m2">,
+    I32EnumAttrCase<"E2M1", 4, "e2m1">
+  ]> {
+  let cppNamespace = "::mlir::triton";
+}
+
 #endif
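The five cases follow the OCP naming: ExMy means 1 sign bit, x exponent bits, and y mantissa bits, which pins down the three element widths in the attribute's name. An illustrative helper, not part of the commit:

// Bit widths implied by the ExMy names above (1 sign + x exponent + y mantissa).
enum class F8F6F4Type { E4M3, E5M2, E2M3, E3M2, E2M1 };

constexpr unsigned bitWidth(F8F6F4Type t) {
  switch (t) {
  case F8F6F4Type::E4M3: // fp8: 1 + 4 + 3
  case F8F6F4Type::E5M2: // fp8: 1 + 5 + 2
    return 8;
  case F8F6F4Type::E2M3: // fp6: 1 + 2 + 3
  case F8F6F4Type::E3M2: // fp6: 1 + 3 + 2
    return 6;
  case F8F6F4Type::E2M1: // fp4: 1 + 2 + 1
    return 4;
  }
  return 0;
}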

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 37 additions & 0 deletions
@@ -673,6 +673,43 @@ def TT_DotOp : TT_Op<"dot", [Pure,
   let hasVerifier = 1;
 }

+
+//
+// DotScaled Op
+//
+def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
+                                          DotLike,
+                                          TypesMatchWith<"result's type matches accumulator's type",
+                                                         "d", "c", "$_self">]> {
+  let summary = "dot_scaled";
+
+  let description = [{
+    $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c.
+    Where scale(x, s) is a function that applies the scale per block following the microscaling spec.
+  }];
+
+  let arguments = (
+    ins
+    // Inputs are integer types as they are packed and we currently
+    // don't have a representation for those.
+    TT_IntTensor:$lhs,
+    TT_IntTensor:$rhs,
+    TT_FloatTensor:$c,
+    TT_IntTensor:$lhs_scale,
+    Optional<TT_IntTensor>:$rhs_scale,
+    TT_F8F6F4TypeAttr:$lhs_type,
+    TT_F8F6F4TypeAttr:$rhs_type
+  );
+
+  let results = (outs TT_FloatTensor:$d);
+
+  // Not sure why the optional group needs to be fully specified, but otherwise
+  // it complains when loading the mlir file.
+  let assemblyFormat = [{
+    $lhs `,` $lhs_scale `,` $rhs (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict
+    `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d)
+  }];
+}
+
 //
 // Reduce Op
 //
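To make the description concrete: under the OCP MX convention of one E8M0 scale (a biased power of two) per 32-element block along K, which matches the shape relation the UpcastMXFPOp verifier below enforces for fp8, the semantics reduce to the following scalar sketch. Names and layouts are illustrative, the rhs scale is omitted as in the mxfp8 × fp8 case this commit targets, and the operands are assumed already decoded from fp8 to float:

// Scalar reference for $d = matmul(scale($lhs, $lhs_scale), $rhs) + $c.
#include <cmath>
#include <cstdint>
#include <vector>

// E8M0 scale: an unsigned exponent with bias 127, i.e. 2^(s - 127).
static float decodeScale(uint8_t e8m0) {
  return std::ldexp(1.0f, int(e8m0) - 127);
}

// lhs is M x K, lhsScale is M x K/32 (one scale per 32 elements of K),
// rhs is K x N, acc is M x N and plays the role of $c (updated in place).
static void dotScaledRef(const std::vector<float> &lhs,
                         const std::vector<uint8_t> &lhsScale,
                         const std::vector<float> &rhs,
                         std::vector<float> &acc, int M, int N, int K) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float d = acc[m * N + n];
      for (int k = 0; k < K; ++k)
        d += lhs[m * K + k] * decodeScale(lhsScale[m * (K / 32) + k / 32]) *
             rhs[k * N + n];
      acc[m * N + n] = d;
    }
}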

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 20 additions & 0 deletions
@@ -256,4 +256,24 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
   }];
 }

+def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Convert an mxfp tensor to bf16";
+
+  let hasVerifier = 1;
+
+  let description = [{
+    Compute the bf16 encoded in the given mxfp number as per
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+  let arguments = (ins
+    TT_Tensor:$src,
+    TT_Tensor:$scale,
+    TT_F8F6F4TypeAttr:$fp_type);
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
+  }];
+}
+
 #endif
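Per the verifier added in Ops.cpp below, $src arrives already widened to bf16 for the fp8 cases, so what remains is applying one E8M0 scale per 32 values. A minimal scalar sketch of that step (float stands in for bf16; the names are illustrative):

// result[i] = src[i] * 2^(scale[i / 32] - 127): one E8M0 scale per 32
// contiguous values, matching xShape.back() == 32 * scaleShape.back() in the
// verifier below for the fp8 case.
#include <cmath>
#include <cstdint>
#include <vector>

static std::vector<float> upcastMXFPRef(const std::vector<float> &src,
                                        const std::vector<uint8_t> &scale) {
  std::vector<float> out(src.size());
  for (size_t i = 0; i < src.size(); ++i)
    out[i] = src[i] * std::ldexp(1.0f, int(scale[i / 32]) - 127);
  return out;
}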

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 1 addition & 2 deletions
@@ -28,8 +28,7 @@ class SharedEncodingAttr;
 // Version = 3: <m, n, k>
 SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
                                                 const ArrayRef<int64_t> &shape,
-                                                RankedTensorType type,
-                                                int numWarps);
+                                                Type type, int numWarps);

 // Return true if the Load uses block pointer.
 bool isLoadFromTensorPtr(triton::LoadOp op);

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 3 additions & 2 deletions
@@ -553,8 +553,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
       GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
       GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
-      GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
-                                                             context);
+      // This assumes the right layout will be set later for dot scaled.
+      GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
+      TritonFuncOpPattern>(typeConverter, context);
 }

 //

lib/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 add_triton_library(TritonGPUIR
   Dialect.cpp
   LinearLayoutConversions.cpp
+  Ops.cpp
   Types.cpp

   DEPENDS

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 3 deletions
@@ -3425,9 +3425,6 @@ void TritonGPUDialect::initialize() {
   addInterfaces<TritonGPUInferLayoutInterface>();
 }

-#define GET_OP_CLASSES
-#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
-
 // verify TritonGPU ops
 LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                          NamedAttribute attr) {

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+#include "mlir/IR/BuiltinTypes.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define GET_OP_CLASSES
+#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
+
+namespace mlir::triton::gpu {
+
+LogicalResult UpcastMXFPOp::verify() {
+  auto fpType = getFpType();
+
+  auto xTy = getSrc().getType();
+  auto scaleTy = getScale().getType();
+
+  if (xTy.getElementType() != FloatType::getBF16(getContext())) {
+    return emitOpError("element type of the first operand must be bf16");
+  }
+
+  if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
+    return emitOpError("element type of the second operand must be uint8");
+  }
+
+  auto xShape = xTy.getShape();
+  auto scaleShape = scaleTy.getShape();
+
+  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
+    return emitOpError(
+        "operands must have the same number of dimensions, at least 2");
+  }
+
+  if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
+        fpType == F8F6F4Type::E5M2)) {
+    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
+  }
+
+  // Change to support fp8 types
+  const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1;
+
+  if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
+    return emitOpError("last dimension of the first operand must be 32 (16 "
+                       "for e2m1) times larger than that of the second "
+                       "operand");
+  }
+
+  if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) {
+    return emitOpError(
+        "all dimensions except the last must match between operands");
+  }
+
+  auto layoutX = xTy.getEncoding();
+  if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
+  auto layoutScale = scaleTy.getEncoding();
+  if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
+    return emitOpError("Expected a BlockedEncodingAttr for scales");
+  }
+  auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);
+
+  // Necessary to keep all of the scales of a given block of values in the
+  // same warp.
+  auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+  if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+    return emitOpError("Expected threads per warp to be {16, 2}");
+  }
+
+  return success();
+}
+
+LogicalResult UpcastMXFPOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties opaqueProperties,
+    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto xTy = cast<RankedTensorType>(operands[0].getType());
+  auto properties = opaqueProperties.as<const Properties *>();
+  auto typeEncoded = properties->fp_type.getValue();
+  auto xShape = xTy.getShape();
+
+  auto encoding = xTy.getEncoding();
+  if (!encoding) {
+    return emitOptionalError(location, "expected an encoding");
+  }
+  if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
+    return emitOptionalError(location, "expected an mma layout encoding");
+  }
+  if (xShape.size() < 2) {
+    return emitOptionalError(location, "tensor rank must be at least 2");
+  }
+
+  // For now we just return the input encoding. For fp4 we'll need to cast
+  // from tf32 to fp16 encoding and multiply the shape by two.
+  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
+         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+
+  inferredReturnTypes.push_back(xTy);
+  return success();
+}
+
+} // namespace mlir::triton::gpu
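The fp4 path is left as future work in the comments above (the result shape doubles and the element encoding changes). For reference, a scalar sketch of the E2M1 decode that path would need, assuming low-nibble-first packing (an assumption; this commit does not specify the nibble order):

// E2M1 (fp4): 1 sign, 2 exponent, 1 mantissa bit; exponent bias 1, so the
// representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4, 6. Two values pack
// into one i8, which is why the verifier uses elems_packed == 2 for E2M1.
#include <cstdint>
#include <utility>

static float decodeE2M1(uint8_t nibble) {
  int sign = (nibble >> 3) & 1;
  int exp = (nibble >> 1) & 3;
  int man = nibble & 1;
  float mag = exp == 0 ? 0.5f * man // subnormal: 0 or 0.5
                       : (1.0f + 0.5f * man) * float(1 << (exp - 1));
  return sign ? -mag : mag;
}

static std::pair<float, float> unpackFp4Byte(uint8_t b) {
  return {decodeE2M1(b & 0xF), decodeE2M1(b >> 4)};
}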
