
Commit bc9ebf3

Add support for AMD Quark BFPQuantizeDequantizeOp
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 99c057f commit bc9ebf3

File tree

8 files changed: +225 −1 lines changed

docs/Dialects/onnx.md

Lines changed: 41 additions & 0 deletions
@@ -1,4 +1,45 @@
 <!-- Autogenerated by mlir-tblgen; don't manually edit -->
+### `onnx.AMDQuarkBFPQuantizeDequantizeOp` (AMDQuarkBFPQuantizeDequantizeOp)
+
+_BFPQuantizeDequantize_
+
+Block Floating Point (BFP) groups numbers (e.g., tensors, arrays) into blocks, where each block shares a common exponent, and the values in the block are represented with individual mantissas (and the sign bit). This approach offers the performance and speed of 8-bit operations while bringing the precision closer to 16-bit operations.
+
+MicroeXponents (MX) extends the concept of BFP by introducing two levels of exponents: shared exponents for entire blocks and micro exponents for finer-grained sub-blocks. This two-level approach enables more precise scaling of individual elements within a block, reducing quantization error and improving the representational range. The paper https://arxiv.org/abs/2302.08007 introduces three specific formats: MX4, MX6, and MX9, which differ in their number of mantissa bits.
+
+This operator converts floating-point values (typically 32-bit floating-point numbers) into BFP or MX values, then converts them back. It approximates the Quantize-Dequantize process and introduces quantization errors.
+
+Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>`, `SameOperandsAndResultElementType`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+<table>
+<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
+<tr><td><code>bfp_method</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
+<tr><td><code>axis</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>bit_width</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>block_size</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>rounding_mode</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>sub_block_size</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+<tr><td><code>sub_block_shift_bits</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
+</table>
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | tensor of 32-bit float values |
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Y` | tensor of 32-bit float values |
+
 ### `onnx.Abs` (ONNXAbsOp)

 _ONNX Abs operation_
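The quantize-dequantize round trip described above can be sketched numerically. The following is a minimal illustration under a straightforward reading of the description: each block shares the exponent of its largest-magnitude element, each value keeps a signed mantissa, and ties round half away from zero (matching `rounding_mode` 0). The exact bit layout Quark uses may differ; this only shows where the quantization error comes from.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// BFP quantize-dequantize round trip: every block of `blockSize` values
// shares the exponent of its largest-magnitude element, and each value is
// reduced to a signed `mantissaBits`-bit mantissa, then scaled back.
std::vector<float> bfpQuantizeDequantize(
    const std::vector<float> &x, int blockSize, int mantissaBits) {
  std::vector<float> y(x.size());
  for (std::size_t start = 0; start < x.size(); start += blockSize) {
    std::size_t end = std::min(start + (std::size_t)blockSize, x.size());
    float maxAbs = 0.0f;
    for (std::size_t i = start; i < end; ++i)
      maxAbs = std::max(maxAbs, std::fabs(x[i]));
    // Shared exponent for the whole block.
    int sharedExp = (maxAbs == 0.0f) ? 0 : (int)std::floor(std::log2(maxAbs));
    // Scale so mantissas land in [-2^(mantissaBits-1), 2^(mantissaBits-1)-1].
    float scale = std::ldexp(1.0f, mantissaBits - 1 - sharedExp);
    float maxMant = std::ldexp(1.0f, mantissaBits - 1) - 1.0f;
    for (std::size_t i = start; i < end; ++i) {
      // std::round rounds half away from zero (rounding_mode 0).
      float mant = std::round(x[i] * scale);
      y[i] = std::min(std::max(mant, -maxMant - 1.0f), maxMant) / scale;
    }
  }
  return y;
}
```

With an 8-bit mantissa the per-element round-trip error is bounded by one mantissa step, about 2^(sharedExp − 7), so values near the block maximum are represented most accurately.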

src/Dialect/ONNX/AMDQuarkOps.td

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+// SPDX-License-Identifier: Apache-2.0
+
+//===-- AMDQuarkOps.td -- AMD Quark Ops -----------------*- tablegen -*-===//
+//
+// Copyright 2025 Advanced Micro Devices, Inc. or its affiliates
+//
+// =============================================================================
+//
+// Defines Ops from AMD's Quark quantizer, version 0.10
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "src/IR/AttrBase.td"
+
+//===----------------------------------------------------------------------===//
+// BFPQuantizeDequantizeOp
+def AMDQuarkBFPQuantizeDequantizeOp: ONNX_Op<"AMDQuarkBFPQuantizeDequantizeOp",
+    [Pure, OpVersionTrait<1>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     DeclareOpInterfaceMethods<ShapeHelperOpInterface>, SameOperandsAndResultElementType]> {
+  let summary = "BFPQuantizeDequantize";
+  let description = [{
+    Block Floating Point (BFP) groups numbers (e.g., tensors, arrays) into blocks, where each block shares a common exponent, and the values in the block are represented with individual mantissas (and the sign bit). This approach offers the performance and speed of 8-bit operations while bringing the precision closer to 16-bit operations.
+
+    MicroeXponents (MX) extends the concept of BFP by introducing two levels of exponents: shared exponents for entire blocks and micro exponents for finer-grained sub-blocks. This two-level approach enables more precise scaling of individual elements within a block, reducing quantization error and improving the representational range. The paper https://arxiv.org/abs/2302.08007 introduces three specific formats: MX4, MX6, and MX9, which differ in their number of mantissa bits.
+
+    This operator converts floating-point values (typically 32-bit floating-point numbers) into BFP or MX values, then converts them back. It approximates the Quantize-Dequantize process and introduces quantization errors.
+  }];
+
+  let arguments = (ins TensorOf<[F32]>:$X,
+    DefaultValuedStrAttr<StrAttr, "to_bfp">:$bfp_method,
+    DefaultValuedAttr<SI64Attr, "1">:$axis,
+    DefaultValuedAttr<SI64Attr, "16">:$bit_width,
+    DefaultValuedAttr<SI64Attr, "8">:$block_size,
+    DefaultValuedAttr<SI64Attr, "0">:$rounding_mode,
+    DefaultValuedAttr<SI64Attr, "2">:$sub_block_size,
+    DefaultValuedAttr<SI64Attr, "1">:$sub_block_shift_bits
+  );
+  let results = (outs TensorOf<[F32]>:$Y);
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() { return 1; }
+    static int getNumberOfResults() { return 1; }
+    static std::vector<int> getTypeMap() { return {30}; } // Same result element type as operand
+    [[nodiscard]] bool isBFP16(bool ignoreAxis = false);
+    [[nodiscard]] bool isMX4(bool ignoreAxis = false);
+    [[nodiscard]] bool isMX6(bool ignoreAxis = false);
+    [[nodiscard]] bool isMX9(bool ignoreAxis = false);
+  }];
+
+  let extraClassDefinition = [{
+    onnx_mlir::ONNXOpShapeHelper *AMDQuarkBFPQuantizeDequantizeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
+        onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
+      onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::AMDQuarkBFPQuantizeDequantizeOpShapeHelper(op, oper, ieb, scope);
+      assert(sh && "failed to allocate shape helper");
+      return sh;
+    }
+  }];
+}

src/Dialect/ONNX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ add_onnx_mlir_library(OMONNXOps
   ONNXTypes.cpp

   # Support for shape inference and verifiers
+  ONNXOps/Additional/AMDQuark.cpp
   ONNXOps/Additional/ConcatShapeTranspose.cpp
   ONNXOps/Additional/Custom.cpp
   ONNXOps/Additional/Dim.cpp

src/Dialect/ONNX/ONNX.td

Lines changed: 1 addition & 0 deletions
@@ -257,5 +257,6 @@ class OpVersionTrait<int version>
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "src/Dialect/ONNX/ONNXOps.td.inc"
 include "src/Dialect/ONNX/AdditionalONNXOps.td"
+include "src/Dialect/ONNX/AMDQuarkOps.td"

 #endif // ONNX_OPS
src/Dialect/ONNX/ONNXOps/Additional/AMDQuark.cpp

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===------------------ AMDQuark.cpp - AMD Quark custom ops ---------------===//
+//
+// Copyright 2025 Advanced Micro Devices, Inc. or its affiliates
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+LogicalResult AMDQuarkBFPQuantizeDequantizeOp::verify() {
+  // Verify that the quantization mode is valid.
+  const auto method = getBfpMethod();
+  if (method != "to_bfp" && method != "to_bfp_prime") {
+    return emitOpError("invalid bfp_method attribute value: " + method +
+                       ". Supported values are 'to_bfp' and 'to_bfp_prime'.");
+  }
+  const int64_t roundingMode = getRoundingMode();
+  if (roundingMode < 0 || roundingMode > 3) {
+    return emitOpError(
+        "invalid rounding_mode attribute value: " +
+        std::to_string(roundingMode) +
+        ". Supported values are 0 for rounding half away from zero, 1 for "
+        "rounding half upward and 2 for rounding half to even.");
+  }
+
+  return success();
+}
+
+namespace {
+struct KnownConfig {
+  StringRef method;
+  int64_t axis;
+  int64_t bit_width;
+  int64_t block_size;
+  int64_t rounding_mode;
+  int64_t sub_block_size;
+  int64_t sub_block_shift_bits;
+};
+
+[[nodiscard]] bool isKnownConfig(AMDQuarkBFPQuantizeDequantizeOp *op,
+    const KnownConfig &config, bool ignoreAxis) {
+  if (op->getBfpMethod() != config.method)
+    return false;
+  if (!ignoreAxis && op->getAxis() != config.axis)
+    return false;
+  if (op->getBitWidth() != config.bit_width)
+    return false;
+  if (op->getBlockSize() != config.block_size)
+    return false;
+  if (op->getRoundingMode() != config.rounding_mode)
+    return false;
+  if (op->getBfpMethod() == "to_bfp_prime") {
+    if (op->getSubBlockSize() != config.sub_block_size)
+      return false;
+    if (op->getSubBlockShiftBits() != config.sub_block_shift_bits)
+      return false;
+  }
+  return true;
+}
+} // namespace
+
+bool AMDQuarkBFPQuantizeDequantizeOp::isBFP16(bool ignoreAxis) {
+  return isKnownConfig(this, {"to_bfp", 1, 16, 8, 2, 0, 0}, ignoreAxis);
+}
+bool AMDQuarkBFPQuantizeDequantizeOp::isMX4(bool ignoreAxis) {
+  return isKnownConfig(this, {"to_bfp_prime", 1, 11, 16, 2, 2, 1}, ignoreAxis);
+}
+bool AMDQuarkBFPQuantizeDequantizeOp::isMX6(bool ignoreAxis) {
+  return isKnownConfig(this, {"to_bfp_prime", 1, 13, 16, 2, 2, 1}, ignoreAxis);
+}
+bool AMDQuarkBFPQuantizeDequantizeOp::isMX9(bool ignoreAxis) {
+  return isKnownConfig(this, {"to_bfp_prime", 1, 16, 16, 2, 2, 1}, ignoreAxis);
+}
+
+LogicalResult AMDQuarkBFPQuantizeDequantizeOp::inferShapes(
+    std::function<void(Region &)> /*doShapeInference*/) {
+  return inferShapeForUnaryOps(this->getOperation());
+}
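The `to_bfp_prime` configurations classified above (MX4/MX6/MX9) add the micro-exponent level described in the op documentation. The sketch below illustrates the idea only: each sub-block may shift its mantissas up by at most 2^`subBlockShiftBits` − 1 positions when all of its values sit well below the block maximum. The function name and the way the bit budget is split are assumptions, not the Quark implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Two-level (block + sub-block) quantize-dequantize sketch. `mantissaBits`
// is the per-element mantissa width; the shared and micro exponent encodings
// of the real MX4/MX6/MX9 formats are not modeled here.
std::vector<float> mxQuantizeDequantize(const std::vector<float> &x,
    int blockSize, int subBlockSize, int subBlockShiftBits, int mantissaBits) {
  std::vector<float> y(x.size());
  const int maxShift = (1 << subBlockShiftBits) - 1;
  for (std::size_t b = 0; b < x.size(); b += blockSize) {
    std::size_t bEnd = std::min(b + (std::size_t)blockSize, x.size());
    float blockMax = 0.0f;
    for (std::size_t i = b; i < bEnd; ++i)
      blockMax = std::max(blockMax, std::fabs(x[i]));
    int blockExp = (blockMax == 0.0f) ? 0 : (int)std::floor(std::log2(blockMax));
    for (std::size_t s = b; s < bEnd; s += subBlockSize) {
      std::size_t sEnd = std::min(s + (std::size_t)subBlockSize, bEnd);
      float subMax = 0.0f;
      for (std::size_t i = s; i < sEnd; ++i)
        subMax = std::max(subMax, std::fabs(x[i]));
      int subExp = (subMax == 0.0f)
          ? blockExp : (int)std::floor(std::log2(subMax));
      // Micro exponent: how far this sub-block sits below the block maximum,
      // capped by the bits available to encode the shift.
      int shift = std::min(maxShift, std::max(0, blockExp - subExp));
      float scale = std::ldexp(1.0f, mantissaBits - 1 - blockExp + shift);
      float maxMant = std::ldexp(1.0f, mantissaBits - 1) - 1.0f;
      for (std::size_t i = s; i < sEnd; ++i) {
        float mant = std::round(x[i] * scale);
        y[i] = std::min(std::max(mant, -maxMant - 1.0f), maxMant) / scale;
      }
    }
  }
  return y;
}
```

The extra shift buys small-valued sub-blocks up to `maxShift` additional bits of effective precision relative to plain BFP, which is the quantization-error reduction the description refers to.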

src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp

Lines changed: 1 addition & 0 deletions
@@ -411,6 +411,7 @@ using ONNXTanOpShapeHelper = ONNXUnaryOpShapeHelper;
 using ONNXTanhOpShapeHelper = ONNXUnaryOpShapeHelper;
 using ONNXThresholdedReluOpShapeHelper = ONNXUnaryOpShapeHelper;
 using ONNXTriluOpShapeHelper = ONNXUnaryOpShapeHelper;
+using AMDQuarkBFPQuantizeDequantizeOpShapeHelper = ONNXUnaryOpShapeHelper;
 // clang-format on

 //===----------------------------------------------------------------------===//

test/mlir/onnx/invalid.mlir

Lines changed: 17 additions & 0 deletions
@@ -1102,4 +1102,21 @@ func.func @test_attention_bad_kv_num_heads(%q: tensor<1x128x3072xf32>, %k: tenso
   // expected-error @+1 {{onnx.Attention: operand '<block argument> of type 'tensor<1x128x1536xf32>' at index: 1' has dimension at index 2 with value 1536, value should be divisible by 15}}
   %out, %present_k, %present_v, %qk_out = "onnx.Attention"(%q, %k, %v, %none, %none, %none) {q_num_heads = 32: si64, kv_num_heads = 15: si64} : (tensor<1x128x3072xf32>, tensor<1x128x1536xf32>, tensor<1x128x768xf32>, none, none, none) -> (tensor<*xf32>, none, none, none)
   return %out : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_bfp_quant_dequant_wrong_method(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  // expected-error @+1 {{'onnx.AMDQuarkBFPQuantizeDequantizeOp' op invalid bfp_method attribute value: from_bfp. Supported values are 'to_bfp' and 'to_bfp_prime'.}}
+  %0 = "onnx.AMDQuarkBFPQuantizeDequantizeOp"(%arg0) {bfp_method = "from_bfp"} : (tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// -----
+
+func.func @test_bfp_quant_dequant_wrong_rounding_mode(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  // expected-error @+1 {{'onnx.AMDQuarkBFPQuantizeDequantizeOp' op invalid rounding_mode attribute value: 4. Supported values are 0 for rounding half away from zero, 1 for rounding half upward and 2 for rounding half to even.}}
+  %0 = "onnx.AMDQuarkBFPQuantizeDequantizeOp"(%arg0) {rounding_mode = 4 : si64} : (tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
 }
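The verifier's error message above names three rounding modes. The helpers below illustrate how they differ on tie values; the function names are illustrative (not Quark or onnx-mlir APIs), and the half-to-even helper assumes the default FE_TONEAREST floating-point environment.

```cpp
#include <cassert>
#include <cmath>

// rounding_mode 0: ties round away from zero (what std::round does).
double roundHalfAwayFromZero(double v) { return std::round(v); }

// rounding_mode 1: ties round upward, i.e. toward +infinity.
double roundHalfUp(double v) { return std::floor(v + 0.5); }

// rounding_mode 2: ties round to the nearest even integer; std::nearbyint
// follows the current rounding direction, FE_TONEAREST by default.
double roundHalfToEven(double v) { return std::nearbyint(v); }
```

The modes only disagree on exact halves: −2.5 becomes −3 away from zero, −2 half-up, and −2 half-to-even, which is why the attribute matters for reproducing a given quantizer bit-for-bit.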

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 15 additions & 1 deletion
@@ -4553,4 +4553,18 @@ func.func @test_attention_3d_q_4d_kv(%q: tensor<1x128x3072xf32>, %k: tensor<1x16
 }
 // CHECK-LABEL: func.func @test_attention_3d_q_4d_kv
 // CHECK: "onnx.Attention"
-// CHECK-SAME: (tensor<1x128x3072xf32>, tensor<1x16x128x96xf32>, tensor<1x16x128x48xf32>, none, none, none) -> (tensor<1x128x1536xf32>
+// CHECK-SAME: (tensor<1x128x3072xf32>, tensor<1x16x128x96xf32>, tensor<1x16x128x48xf32>, none, none, none) -> (tensor<1x128x1536xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+/// Test shape inference for amd.quark.BFPQuantizeDequantize
+//===----------------------------------------------------------------------===//
+
+func.func @test_bfp_quant_dequant(%arg0: tensor<16x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.AMDQuarkBFPQuantizeDequantizeOp"(%arg0) : (tensor<16x32xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: func.func @test_bfp_quant_dequant
+// CHECK: "onnx.AMDQuarkBFPQuantizeDequantizeOp"
+// CHECK-SAME: (tensor<16x32xf32>) -> tensor<16x32xf32>

0 commit comments