Skip to content

Commit 7c471a9

Browse files
lhutton1tatwaichong
authored andcommitted
[mlir][tosa] Add support for matmul_t_block_scaled (llvm#163433)
This commit adds support for the MATMUL_T_BLOCK_SCALED operation from the EXT_MXFP extension. This includes: - Operation definition in TosaOps.td - Micro-scaling supported types definition - Shape inference and verifiers - Validation pass checks to ensure usage is only valid when the target environment includes ext-mxfp and at least v1.1.draft of the specification. As part of this commit, a notion of EXT_MXFP is also added. The extension can be specified as part of the target environment and can only be used if the specification version is at least 1.1. Note: currently it excludes support for mxint8. This will be added in a later commit. Note: this commit adds support as defined in the spec in arm/tosa-specification@063846a. EXT_MXFP extension is considered experimental and subject to breaking change. Co-authored-by: Tat Wai Chong <[email protected]>
1 parent 025e2db commit 7c471a9

17 files changed

+597
-45
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
5454
/// and provide utilities around the TOSA specification version.
5555
class TosaSpecificationVersion {
5656
public:
57+
TosaSpecificationVersion() = default;
58+
5759
TosaSpecificationVersion(uint32_t major, uint32_t minor)
5860
: majorVersion(major), minorVersion(minor) {}
5961
TosaSpecificationVersion(SpecificationVersion version)
@@ -83,6 +85,10 @@ class TosaSpecificationVersion {
8385
}
8486
};
8587

88+
TosaSpecificationVersion getMinVersion(const Profile &profile);
89+
TosaSpecificationVersion getMinVersion(const Extension &extension);
90+
TosaSpecificationVersion getMinVersion(const Level &level);
91+
8692
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
8793

8894
/// This class represents the capability enabled in the target implementation
@@ -91,22 +97,19 @@ llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
9197
class TargetEnv {
9298
public:
9399
TargetEnv() {}
94-
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
95-
const ArrayRef<Profile> &profiles,
96-
const ArrayRef<Extension> &extensions)
97-
: specificationVersion(specificationVersion), level(level) {
98-
enabledProfiles.insert_range(profiles);
99-
enabledExtensions.insert_range(extensions);
100-
}
101100

102-
explicit TargetEnv(TargetEnvAttr targetAttr)
103-
: TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
104-
targetAttr.getProfiles(), targetAttr.getExtensions()) {}
101+
static FailureOr<TargetEnv>
102+
createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc);
103+
104+
static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr,
105+
Location targetAttrLoc);
105106

106107
void addProfile(Profile p) { enabledProfiles.insert(p); }
107108
void addExtension(Extension e) { enabledExtensions.insert(e); }
108109

109-
SpecificationVersion getSpecVersion() const { return specificationVersion; }
110+
TosaSpecificationVersion getSpecVersion() const {
111+
return specificationVersion;
112+
}
110113

111114
TosaLevel getLevel() const {
112115
if (level == Level::eightK)
@@ -140,7 +143,17 @@ class TargetEnv {
140143
}
141144

142145
private:
143-
SpecificationVersion specificationVersion;
146+
// Require target information is verified before constructing, via the use of
147+
// `createTargetEnvFromAttr`.
148+
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
149+
const ArrayRef<Profile> &profiles,
150+
const ArrayRef<Extension> &extensions)
151+
: specificationVersion(specificationVersion), level(level) {
152+
enabledProfiles.insert_range(profiles);
153+
enabledExtensions.insert_range(extensions);
154+
}
155+
156+
TosaSpecificationVersion specificationVersion;
144157
Level level;
145158
llvm::SmallSet<Profile, 3> enabledProfiles;
146159
llvm::SmallSet<Extension, 13> enabledExtensions;

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,18 @@ extensionComplianceMap = {
554554
allOf},
555555
{{Extension::bf16},
556556
{{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
557+
{"tosa.matmul_t_block_scaled",
558+
{{{Extension::mxfp},
559+
{{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T},
560+
SpecificationVersion::V_1_1_DRAFT},
561+
{{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T},
562+
SpecificationVersion::V_1_1_DRAFT},
563+
{{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T},
564+
SpecificationVersion::V_1_1_DRAFT},
565+
{{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
566+
SpecificationVersion::V_1_1_DRAFT},
567+
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
568+
SpecificationVersion::V_1_1_DRAFT}}}}},
557569
{"tosa.max_pool2d",
558570
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
559571
{{Extension::fp8e4m3},

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,21 +270,22 @@ def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
270270
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
271271
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
272272
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
273+
def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
273274

274275
def Tosa_ExtensionAttr
275276
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
276277
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
277278
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
278279
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
279-
Tosa_EXT_DYNAMIC
280+
Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP
280281
]> {
281282
let extraClassDeclaration = [{
282283
static llvm::SmallVector<Extension, 11> getAllValues() {
283284
return {
284285
Extension::int16, Extension::int4, Extension::bf16,
285286
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
286287
Extension::variable, Extension::controlflow, Extension::doubleround,
287-
Extension::inexactround, Extension::dynamic
288+
Extension::inexactround, Extension::dynamic, Extension::mxfp
288289
};
289290
}
290291
}];
@@ -437,7 +438,7 @@ def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
437438
}
438439

439440
//===----------------------------------------------------------------------===//
440-
// Iterable attributes.
441+
// Enum attributes.
441442
//===----------------------------------------------------------------------===//
442443
// Defined in `section 3. Enumerations` of the TOSA specification.
443444

@@ -463,6 +464,18 @@ def Tosa_RoundingModeAttr
463464
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
464465
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
465466

467+
def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
468+
469+
def Tosa_BlockSizeAttr
470+
: Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
471+
[Tosa_BLOCK_SIZE_32]> {
472+
let extraClassDeclaration = [{
473+
static uint32_t getBlockSizeValue(BlockSize blockSize) {
474+
return static_cast<uint32_t>(blockSize);
475+
}
476+
}];
477+
}
478+
466479

467480
//===----------------------------------------------------------------------===//
468481
// TOSA Interfaces.

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,40 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
347347
"operands attr-dict `:` functional-type(operands, results)";
348348
}
349349

350+
//===----------------------------------------------------------------------===//
351+
// Operator: matmul_t_block_scaled
352+
//===----------------------------------------------------------------------===//
353+
def Tosa_MatmulTBlockScaledOp : Tosa_InferShapedTypeOp<"matmul_t_block_scaled"> {
354+
let summary = "Performs two dimensional matrix multiplications using block scaled tensors.";
355+
356+
let description = [{
357+
Performs two dimensional matrix multiplications using block scaled tensors. The block
358+
dimension is always the the last dimension of the tensor, so the result is effectively
359+
a matrix multiply of A by the transposed B matrix. If the N dimension of input B is of
360+
size 1, the B matrix will be broadcast.
361+
}];
362+
363+
let arguments = (ins
364+
Tosa_MXFPDataTensor3D:$a_data,
365+
Tosa_MXFPScaleTensor3D:$a_scale,
366+
Tosa_MXFPDataTensor3D:$b_data,
367+
Tosa_MXFPScaleTensor3D:$b_scale,
368+
Tosa_BlockSizeAttr:$block_size
369+
);
370+
371+
let results = (outs
372+
Tosa_Tensor3D:$output_data
373+
);
374+
375+
let hasVerifier = 1;
376+
let hasCustomAssemblyFormat = 1;
377+
378+
list<Availability> availability = [
379+
Profile<[Tosa_PRO_FP]>,
380+
Extension<[Tosa_EXT_MXFP]>
381+
];
382+
}
383+
350384
//===----------------------------------------------------------------------===//
351385
// Operator: max_pool2d
352386
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class TosaProfileCompliance {
147147
case Extension::fp8e4m3:
148148
case Extension::fp8e5m2:
149149
case Extension::fft:
150+
case Extension::mxfp:
150151
return {Profile::pro_fp};
151152
case Extension::variable:
152153
case Extension::controlflow:

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
8484
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
8585
"number">;
8686

87+
def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
88+
"micro-scaling format number">;
89+
def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
90+
8791
//===----------------------------------------------------------------------===//
8892
// TOSA Tensor Conformance
8993
//===----------------------------------------------------------------------===//
@@ -187,6 +191,15 @@ def Tosa_Int32Tensor2D : AnyTypeOf<[
187191
def Tosa_TensorAtLeast1D : AnyTypeOf<[
188192
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
189193

194+
def Tosa_MXFPDataTensor3D : AnyTypeOf<[
195+
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
196+
TosaTensorRankOf<[Tosa_MXFPNumber], [3]>
197+
]>;
198+
def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
199+
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
200+
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
201+
]>;
202+
190203
//===----------------------------------------------------------------------===//
191204
// Generic scalar, vector, or tensor of a particular type.
192205
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,96 @@
1212
namespace mlir {
1313
namespace tosa {
1414

15+
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
16+
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
17+
}
18+
19+
TosaSpecificationVersion getMinVersion(const Profile &profile) {
20+
switch (profile) {
21+
case Profile::pro_int:
22+
case Profile::pro_fp:
23+
return TosaSpecificationVersion(1, 0);
24+
case Profile::none:
25+
return TosaSpecificationVersion(0, 0);
26+
}
27+
llvm_unreachable("Unknown TOSA profile");
28+
}
29+
30+
TosaSpecificationVersion getMinVersion(const Extension &extension) {
31+
switch (extension) {
32+
case Extension::int16:
33+
case Extension::int4:
34+
case Extension::bf16:
35+
case Extension::fp8e4m3:
36+
case Extension::fp8e5m2:
37+
case Extension::fft:
38+
case Extension::variable:
39+
case Extension::controlflow:
40+
case Extension::doubleround:
41+
case Extension::inexactround:
42+
case Extension::dynamic:
43+
return TosaSpecificationVersion(1, 0);
44+
case Extension::mxfp:
45+
return TosaSpecificationVersion(1, 1);
46+
case Extension::none:
47+
return TosaSpecificationVersion(0, 0);
48+
}
49+
llvm_unreachable("Unknown TOSA extension");
50+
}
51+
52+
TosaSpecificationVersion getMinVersion(const Level &level) {
53+
switch (level) {
54+
case Level::eightK:
55+
case Level::none:
56+
return TosaSpecificationVersion(1, 0);
57+
}
58+
llvm_unreachable("Unknown TOSA level");
59+
}
60+
61+
FailureOr<TargetEnv>
62+
TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
63+
Location targetEnvAttrLoc) {
64+
if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
65+
return failure();
66+
67+
return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
68+
targetAttr.getProfiles(), targetAttr.getExtensions());
69+
}
70+
71+
LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
72+
Location targetAttrLoc) {
73+
TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
74+
75+
const auto isCompatibleWithTargetVersion =
76+
[&](const auto &targetEnum, Location targetAttrLoc,
77+
StringRef enumName) -> LogicalResult {
78+
const TosaSpecificationVersion minRequiredVersion =
79+
getMinVersion(targetEnum);
80+
if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
81+
return emitError(targetAttrLoc, enumName)
82+
<< " '" << stringifyEnum(targetEnum)
83+
<< "' is not compatible with the target version "
84+
<< stringifyVersion(targetVersion)
85+
<< ", minimum required version is "
86+
<< stringifyVersion(minRequiredVersion);
87+
return success();
88+
};
89+
90+
for (const auto &profile : targetAttr.getProfiles())
91+
if (failed(
92+
isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
93+
return failure();
94+
for (const auto &extension : targetAttr.getExtensions())
95+
if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
96+
"extension")))
97+
return failure();
98+
if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
99+
"level")))
100+
return failure();
101+
102+
return success();
103+
}
104+
15105
TargetEnvAttr lookupTargetEnv(Operation *op) {
16106
while (op) {
17107
op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
39129
return getDefaultTargetEnv(op->getContext());
40130
}
41131

42-
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
43-
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
44-
}
45-
46132
} // namespace tosa
47133
} // namespace mlir

0 commit comments

Comments
 (0)