Skip to content

Commit 8f9e69d

Browse files
lhutton1 authored and tatwaichong committed
[mlir][tosa] Add support for mxint8 type in mxfp operations (llvm#163642)
This commit adds support for the OCP-MX INT8 type. This includes the following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED, CAST_TO_BLOCK_SCALED and CONST. The support is added via a custom TOSA type "!tosa.mxint8" because it is not yet a builtin type in MLIR. This may change in the future, depending on how this type is used by other frameworks/dialects. Conversions to/from this type have not yet been implemented for the same reason. Co-authored-by: Tat Wai Chong <[email protected]>
1 parent 92e1ce3 commit 8f9e69d

File tree

9 files changed

+121
-23
lines changed

9 files changed

+121
-23
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,8 @@ extensionComplianceMap = {
572572
{{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
573573
SpecificationVersion::V_1_1_DRAFT},
574574
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
575+
SpecificationVersion::V_1_1_DRAFT},
576+
{{mxint8T, fp8ue8m0T, mxint8T, fp8ue8m0T, fp32T},
575577
SpecificationVersion::V_1_1_DRAFT}}}}},
576578
{"tosa.max_pool2d",
577579
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
@@ -870,27 +872,31 @@ extensionComplianceMap = {
870872
{{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
871873
{{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
872874
{{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
873-
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
875+
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
876+
{{mxint8T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
874877
allOf},
875878
{{Extension::mxfp},
876879
{{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
877880
{{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
878881
{{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
879882
{{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
880-
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
883+
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
884+
{{mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
881885
{"tosa.cast_to_block_scaled",
882886
{{{Extension::mxfp},
883887
{{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
884888
{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
885889
{{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
886890
{{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
887891
{{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
888-
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
892+
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
893+
{{fp32T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
889894
{{Extension::bf16, Extension::mxfp},
890895
{{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
891896
{{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
892897
{{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
893-
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
898+
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
899+
{{bf16T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
894900
allOf}}},
895901
{"tosa.rescale",
896902
{{{Extension::int16},
@@ -908,7 +914,8 @@ extensionComplianceMap = {
908914
{{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
909915
{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
910916
{{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
911-
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
917+
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
918+
{{mxint8T}, SpecificationVersion::V_1_1_DRAFT}}}}},
912919
{"tosa.identity",
913920
{{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
914921
{{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
179179
// returns type of variable op
180180
RankedTensorType getVariableType(VariableOp variableOp);
181181

182+
// Returns the bit width of a TOSA tensor element type (8 for !tosa.mxint8).
183+
unsigned getBitWidth(Type type);
184+
182185
} // namespace tosa
183186
} // namespace mlir
184187

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ProfileInfoDepot {
7070

7171
private:
7272
TypeInfo convertTypeToInfo(Type type) {
73-
return {type.getTypeID(), type.getIntOrFloatBitWidth()};
73+
return {type.getTypeID(), tosa::getBitWidth(type)};
7474
}
7575

7676
TypeInfo convertValueToInfo(Value value) {

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
2222
// Tosa Type Definitions.
2323
//===----------------------------------------------------------------------===//
2424

25+
// The base class for Tosa dialect types. Forwards `name` and `traits` to
// TypeDef for Tosa_Dialect and uses `typeMnemonic` as the type's mnemonic.
26+
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
27+
: TypeDef<Tosa_Dialect, name, traits> {
28+
let mnemonic = typeMnemonic;
29+
}
30+
2531
// The base class of a quantized type.
2632
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
2733
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
@@ -78,13 +84,26 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
7884
Tosa_QuantizedType<"int16", [16, 0], 1>,
7985
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
8086

87+
//===----------------------------------------------------------------------===//
88+
// Custom TOSA element types.
89+
//===----------------------------------------------------------------------===//
90+
91+
// MLIR doesn't have a builtin type for mxint8 yet. For now, it is declared as
92+
// a custom TOSA type. This may change in the future.
93+
def Tosa_MXInt8 : Tosa_Type<"mxint8", "mxint8"> {
94+
let summary = "INT8 type as defined by OCP-MX";
95+
let description = [{
96+
8-bit integer format with an implicit 1/64 scale defined by OCP-MX.
97+
}];
98+
}
99+
81100
//===----------------------------------------------------------------------===//
82101
// Multi-category types.
83102
//===----------------------------------------------------------------------===//
84-
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
103+
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat, Tosa_MXInt8],
85104
"number">;
86105

87-
def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
106+
def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN, Tosa_MXInt8],
88107
"micro-scaling format number">;
89108
def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
90109

@@ -265,16 +284,6 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
265284
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
266285
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
267286

268-
//===----------------------------------------------------------------------===//
269-
// Tosa Type Definitions.
270-
//===----------------------------------------------------------------------===//
271-
272-
// The base class for Tosa dialect types.
273-
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
274-
: TypeDef<Tosa_Dialect, name, traits> {
275-
let mnemonic = typeMnemonic;
276-
}
277-
278287
//===----------------------------------------------------------------------===//
279288
// ShapeType
280289
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
606606
return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
607607
}
608608

609+
// Returns the bit width of a TOSA tensor element type. The custom
// !tosa.mxint8 type is not a builtin integer or float type, so it is
// special-cased to 8 bits; every other type is forwarded to
// Type::getIntOrFloatBitWidth().
unsigned mlir::tosa::getBitWidth(Type type) {
610+
  // Only a type test is needed here, so use isa<> rather than dyn_cast<>.
  if (isa<tosa::mxint8Type>(type))
611+
    return 8;
612+
  return type.getIntOrFloatBitWidth();
613+
}
614+
609615
//===----------------------------------------------------------------------===//
610616
// TOSA Operator Verifiers.
611617
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ TosaProfileCompliance::TosaProfileCompliance() {
3131
const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
3232
const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
3333
const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
34+
const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
3435

3536
// The profile-based compliance content below is auto-generated by a script
3637
// in https://git.mlplatform.org/tosa/specification.git
@@ -625,6 +626,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
625626
return {"fp4e2m1"};
626627
} else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
627628
return {"fp8e8m0"};
629+
} else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
630+
return {"mxint8"};
628631
}
629632
llvm_unreachable("unknown type");
630633
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
693693
<< " shape dimension cannot be dynamic";
694694
}
695695

696-
int64_t element_bits = type.getElementTypeBitWidth();
696+
int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
697697
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
698698
int64_t size = element_bytes * type.getNumElements();
699699

@@ -1217,9 +1217,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
12171217
return true;
12181218
}
12191219
}
1220-
} else if (mlir::isa<tosa::shapeType>(type)) {
1220+
} else if (isa<tosa::shapeType>(type))
1221+
return true;
1222+
else if (isa<tosa::mxint8Type>(type))
12211223
return true;
1222-
}
12231224
return false;
12241225
}
12251226

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,13 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
12691269
return %0 : tensor<4x8x16xf32>
12701270
}
12711271

1272+
// -----
1273+
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
1274+
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
1275+
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
1276+
return %0 : tensor<4x8x16xf32>
1277+
}
1278+
12721279
// -----
12731280
// CHECK-LABEL: test_cast_from_block_scaled_static
12741281
func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
@@ -1296,3 +1303,17 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
12961303
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
12971304
return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
12981305
}
1306+
1307+
// -----
1308+
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
1309+
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
1310+
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
1311+
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
1312+
}
1313+
1314+
// -----
1315+
// CHECK-LABEL: test_const_mxint8
1316+
func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
1317+
%0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
1318+
return %0 : tensor<2x!tosa.mxint8>
1319+
}

mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64
3838
// -----
3939

4040
// CHECK-LABEL: test_const_i64
41-
func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
41+
func.func @test_const_i64() -> tensor<4xi64> {
4242
%0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
4343
return %0 : tensor<4xi64>
4444
}
4545

4646
// -----
4747

4848
// CHECK-LABEL: test_const_fp6e3m2
49-
func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
49+
func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> {
5050
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
5151
return %0 : tensor<4xf6E3M2FN>
5252
}
@@ -82,3 +82,51 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
8282
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
8383
return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
8484
}
85+
86+
// -----
87+
88+
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
89+
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
90+
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
91+
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
92+
}
93+
94+
// -----
95+
96+
// CHECK-LABEL: test_const_fp6e3m2
97+
func.func @test_const_fp6e3m2() -> tensor<4xf6E3M2FN> {
98+
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
99+
return %0 : tensor<4xf6E3M2FN>
100+
}
101+
102+
// -----
103+
104+
// CHECK-LABEL: test_const_mxint8
105+
func.func @test_const_mxint8() -> tensor<2x!tosa.mxint8> {
106+
%0 = "tosa.const"() {values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
107+
return %0 : tensor<2x!tosa.mxint8>
108+
}
109+
110+
// -----
111+
112+
// CHECK-LABEL: test_cast_f4e2m1
113+
func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
114+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
115+
return %0 : tensor<13x21x3xbf16>
116+
}
117+
118+
// -----
119+
120+
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
121+
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
122+
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
123+
return %0 : tensor<4x8x16xf32>
124+
}
125+
126+
// -----
127+
128+
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
129+
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
130+
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
131+
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
132+
}

0 commit comments

Comments
 (0)