-
Couldn't load subscription status.
- Fork 15k
[mlir][tosa] Add ext-mxfp support for const and cast ops #163641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
This commit allows const and cast ops with MXFP datatypes through the validation pass when specification version 1.1.draft is selected. Note: it doesn't include support for the mxint8 datatype. This will be added in a separate commit. Note: this commit adds support as defined in the spec in https://review.mlplatform.org/c/tosa/specification/+/15362. EXT_MXFP extension is considered experimental and subject to breaking change. Change-Id: Idd0477bc947ade524b0fb0213cc7e8d4f892ddab
106a171 to
faf61e3
Compare
|
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) Changes: This commit allows const and cast ops with MXFP datatypes through the validation pass when specification version 1.1.draft is selected. Note: it doesn't include support for the mxint8 datatype. This will be added in a separate commit. Note: this commit adds support as defined in the spec in arm/tosa-specification@063846a. EXT_MXFP extension is considered experimental and subject to breaking change. Note: This PR relies on #156425, #163433 and #163436 so also contains their contents. Full diff: https://github.com/llvm/llvm-project/pull/163641.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 06e4ee0c4176d..6e78b75f37d10 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -855,7 +855,15 @@ extensionComplianceMap = {
{{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
{{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
{{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
- {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
+ {{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
+ {{Extension::bf16, Extension::mxfp},
+ {{{fp4e2m1T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
+ {{bf16T, fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}},
+ allOf}}},
{"tosa.rescale",
{{{Extension::int16},
{{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -867,7 +875,12 @@ extensionComplianceMap = {
{{Extension::int64}, {{{i64T}, SpecificationVersion::V_1_1_DRAFT}}},
{{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
- {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
+ {{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}},
+ {{Extension::mxfp},
+ {{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.identity",
{{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
{{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddd8cba5f9dd5..0e3df21f43804 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2462,7 +2462,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
list<Availability> availability = [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_INT64]>,
+ Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
];
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
@@ -2578,7 +2578,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
list<Availability> availability = [
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_INT64]>,
+ Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
];
let hasFolder = 1;
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 58ef0d3e5ae59..c138ac9bab2c4 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
// CHECK-LABEL: cast
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, int64] ]
+ // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -626,7 +626,7 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
// CHECK-LABEL: test_const
func.func @test_const(%arg0 : index) -> tensor<4xi32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
- // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16, int64] ]
+ // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16, mxfp, int64] ]
%0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index a214183b55876..ab048afd1ca0b 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -555,3 +555,17 @@ func.func @test_argmax_int64(%arg0: tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64
%0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x13x13x5xf32>) -> tensor<1x13x13xi64>
return %0 : tensor<1x13x13xi64>
}
+
+// -----
+func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
+ // expected-error@+1 {{'tosa.const' op illegal: requires [mxfp] but not enabled in target}}
+ %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
+ return %0 : tensor<4xf6E3M2FN>
+}
+
+// -----
+func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
+ // expected-error@+1 {{'tosa.cast' op illegal: requires all of [bf16, mxfp] but not enabled in target}}
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
+ return %0 : tensor<13x21x3xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 2301d2febb5c3..926e7f2798c23 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -42,3 +42,19 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
%0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
return %0 : tensor<4xi64>
}
+
+// -----
+
+// CHECK-LABEL: test_const_fp6e3m2
+func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
+ %0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
+ return %0 : tensor<4xf6E3M2FN>
+}
+
+// -----
+
+// CHECK-LABEL: test_cast_f4e2m1
+func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
+ return %0 : tensor<13x21x3xbf16>
+}
|
This commit allows const and cast ops with MXFP datatypes through the validation pass when specification version 1.1.draft is selected. Note: it doesn't include support for the mxint8 datatype. This will be added in a separate commit. Note: this commit adds support as defined in the spec in arm/tosa-specification@063846a. EXT_MXFP extension is considered experimental and subject to breaking change.
This commit allows const and cast ops with MXFP datatypes through the validation pass when specification version 1.1.draft is selected. Note: it doesn't include support for the mxint8 datatype. This will be added in a separate commit. Note: this commit adds support as defined in the spec in arm/tosa-specification@063846a. EXT_MXFP extension is considered experimental and subject to breaking change.
This commit allows const and cast ops with MXFP datatypes through the validation pass when specification version 1.1.draft is selected.
Note: it doesn't include support for the mxint8 datatype. This will be added in a separate commit.
Note: this commit adds support as defined in the spec in arm/tosa-specification@063846a. EXT_MXFP extension is considered experimental and subject to breaking change.
Note: This PR relies on #156425, #163433 and #163436, so it also contains their changes.