Commit ae99b99

Merge pull request #479 from Xilinx/jrickert.bfpqdqbf16
Add BF16 support to AMDQuarkBFPQuantizeDequantizeOp
2 parents: d6a01d2 + e3ebc5f

2 files changed: +13 −3 lines

src/Dialect/ONNX/AMDQuarkOps.td

Lines changed: 4 additions & 2 deletions
@@ -27,9 +27,11 @@ def AMDQuarkBFPQuantizeDequantizeOp: ONNX_Op<"AMDQuarkBFPQuantizeDequantizeOp",
     MicroeXponents (MX) extends the concept of BFP by introducing two levels of exponents: shared exponents for entire blocks and micro exponents for finer-grained sub-blocks. This two-level approach enables more precise scaling of individual elements within a block, reducing quantization error and improving the representational range. The paper https://arxiv.org/abs/2302.08007 introduces three specific formats: MX4, MX6 and MX9, which have different bits of mantissa.
 
     This operator converts floating-point values (typically 32-bit floating-point numbers) into BFP or MX values, then converts them back. It approximates the Quantize-Dequantize process and introduces quantization errors.
+
+    Support for BF16 is an AMD extension in ONNX-MLIR to https://quark.docs.amd.com/latest/onnx/custom_operators/BFPQuantizeDequantize.html.
   }];
 
-  let arguments = (ins TensorOf<[F32]>:$X,
+  let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[BF16]>]>:$X,
       DefaultValuedStrAttr<StrAttr, "to_bfp">:$bfp_method,
       DefaultValuedAttr<SI64Attr, "1">:$axis,
       DefaultValuedAttr<SI64Attr, "16">:$bit_width,
@@ -38,7 +40,7 @@ def AMDQuarkBFPQuantizeDequantizeOp: ONNX_Op<"AMDQuarkBFPQuantizeDequantizeOp",
       DefaultValuedAttr<SI64Attr, "2">:$sub_block_size,
       DefaultValuedAttr<SI64Attr, "1">:$sub_block_shift_bits
   );
-  let results = (outs TensorOf<[F32]>:$Y);
+  let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[BF16]>]>:$Y);
 
   let hasVerifier = 1;
 
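For intuition about the quantize-dequantize behavior the description above promises, here is a minimal NumPy sketch of a plain BFP round trip: one shared exponent per block, a fixed number of mantissa bits per element. The function name and the block_size/mantissa_bits parameters are hypothetical illustrations, not the op's attributes, and the sketch omits the MX sub-block micro exponents; it is not the Quark kernel.

import numpy as np

def bfp_quant_dequant(x, block_size=8, mantissa_bits=8):
    # Illustrative only: every element in a block shares the exponent of
    # the block's largest magnitude, and each element keeps mantissa_bits
    # bits. x.size must be divisible by block_size in this sketch.
    x = np.asarray(x, dtype=np.float32)
    blocks = x.reshape(-1, block_size)
    out = np.empty_like(blocks)
    qmax = 2 ** (mantissa_bits - 1) - 1
    for i, block in enumerate(blocks):
        max_mag = np.max(np.abs(block))
        if max_mag == 0.0:
            out[i] = 0.0
            continue
        shared_exp = np.floor(np.log2(max_mag))
        # One step of the shared scale is one unit in the last place of
        # the quantized mantissa.
        scale = 2.0 ** (shared_exp - (mantissa_bits - 2))
        out[i] = np.clip(np.round(block / scale), -qmax - 1, qmax) * scale
    return out.reshape(x.shape)

# The round trip only approximates the input, as the description says:
print(bfp_quant_dequant(np.linspace(-1.0, 1.0, 16)))

Small inputs dominated by one large element show the characteristic BFP error: the shared exponent is set by the largest magnitude, so small elements in the same block lose relative precision, which is exactly what the MX micro exponents are meant to mitigate.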
test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 9 additions & 1 deletion
@@ -4567,4 +4567,12 @@ func.func @test_bfp_quant_dequant(%arg0: tensor<16x32xf32>) -> tensor<*xf32> {
 }
 // CHECK-LABEL: func.func @test_bfp_quant_dequant
 // CHECK: "onnx.AMDQuarkBFPQuantizeDequantizeOp"
-// CHECK-SAME: (tensor<16x32xf32>) -> tensor<16x32xf32>
+// CHECK-SAME: (tensor<16x32xf32>) -> tensor<16x32xf32>
+
+func.func @test_bfp_quant_dequant_bf16(%arg0: tensor<16x32xbf16>) -> tensor<*xbf16> {
+  %0 = "onnx.AMDQuarkBFPQuantizeDequantizeOp"(%arg0) : (tensor<16x32xbf16>) -> tensor<*xbf16>
+  return %0 : tensor<*xbf16>
+}
+// CHECK-LABEL: func.func @test_bfp_quant_dequant_bf16
+// CHECK: "onnx.AMDQuarkBFPQuantizeDequantizeOp"
+// CHECK-SAME: (tensor<16x32xbf16>) -> tensor<16x32xbf16>
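The new test mirrors the f32 case and checks that shape inference propagates the ranked 16x32 bf16 type through the op. For readers unfamiliar with the element type: bfloat16 is the high 16 bits of an IEEE-754 float32 (1 sign bit, the same 8 exponent bits, 7 stored mantissa bits), so it keeps float32's dynamic range at reduced precision. A minimal NumPy sketch of that correspondence, using truncation rather than the round-to-nearest-even that converters normally apply; the helper name is made up for illustration:

import numpy as np

def f32_to_bf16_trunc(x):
    # bfloat16 is the top half of a float32 bit pattern; zeroing the low
    # 16 mantissa bits yields the (truncated) bf16 value, still stored
    # as float32 here. Hardware converters round to nearest even instead.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

# Same dynamic range as float32, far coarser mantissa:
print(f32_to_bf16_trunc([1.2345678, 3.0e38, 1.0e-38]))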
