|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +//===-- AMDQuarkOps.td -- AMD Quark Ops -*- tablegen -*-===//
| 4 | +// |
| 5 | +// Copyright 2025 Advanced Micro Devices, Inc. or its affiliates |
| 6 | +// |
| 7 | +// ============================================================================= |
| 8 | +// |
| 9 | +// Defines Ops from AMD's Quark quantizer, version 0.10 |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +include "mlir/Interfaces/CallInterfaces.td" |
| 14 | +include "mlir/IR/SymbolInterfaces.td" |
| 15 | +include "src/IR/AttrBase.td" |
| 16 | + |
| 17 | +//===----------------------------------------------------------------------===// |
| 18 | +// BFPQuantizeDequantizeOp |
| 19 | +def AMDQuarkBFPQuantizeDequantizeOp: ONNX_Op<"AMDQuarkBFPQuantizeDequantizeOp",
| 20 | + [Pure, OpVersionTrait<1>,
| 21 | + DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
| 22 | + DeclareOpInterfaceMethods<ShapeHelperOpInterface>, SameOperandsAndResultElementType]> {
| 23 | + let summary = "BFPQuantizeDequantize";
| 24 | + let description = [{
| 25 | + Block Floating Point (BFP) groups numbers (e.g., tensors, arrays) into blocks, where each block shares a common exponent, and the values in the block are represented with individual mantissas (and the sign bit). This approach offers the performance and speed of 8-bit operations while bringing the precision closer to 16-bit operations.
| 26 | +
| 27 | + MicroeXponents (MX) extends the concept of BFP by introducing two levels of exponents: shared exponents for entire blocks and micro exponents for finer-grained sub-blocks. This two-level approach enables more precise scaling of individual elements within a block, reducing quantization error and improving the representational range. The paper https://arxiv.org/abs/2302.08007 introduces three specific formats: MX4, MX6 and MX9, which have different bits of mantissa.
| 28 | +
| 29 | + This operator converts floating-point values (typically 32-bit floating-point numbers) into BFP or MX values, then convert them back. It approximates the Quantize-Dequantize process and introduces quantization errors.
| 30 | + }];
| 31 | +
| 32 | + let arguments = (ins TensorOf<[F32]>:$X, // f32 tensor to quantize-dequantize
| 33 | + DefaultValuedStrAttr<StrAttr, "to_bfp">:$bfp_method, // conversion mode; presumably selects BFP vs. MX — confirm allowed values against Quark 0.10
| 34 | + DefaultValuedAttr<SI64Attr, "1">:$axis, // axis along which elements are grouped into blocks — exact convention per Quark, TODO confirm
| 35 | + DefaultValuedAttr<SI64Attr, "16">:$bit_width, // bit width of the quantized representation — semantics defined by Quark, TODO confirm
| 36 | + DefaultValuedAttr<SI64Attr, "8">:$block_size, // number of elements sharing one exponent — presumably; verify
| 37 | + DefaultValuedAttr<SI64Attr, "0">:$rounding_mode, // rounding scheme selector; enum meaning defined by Quark — TODO confirm
| 38 | + DefaultValuedAttr<SI64Attr, "2">:$sub_block_size, // MX sub-block size — presumably ignored for plain BFP; verify
| 39 | + DefaultValuedAttr<SI64Attr, "1">:$sub_block_shift_bits // bits for the MX micro-exponent shift — presumably ignored for plain BFP; verify
| 40 | + );
| 41 | + let results = (outs TensorOf<[F32]>:$Y); // element type matches $X per SameOperandsAndResultElementType
| 42 | +
| 43 | + let hasVerifier = 1; // custom verify() is implemented in C++
| 44 | +
| 45 | + let extraClassDeclaration = [{
| 46 | + static int getNumberOfOperands() { return 1; }
| 47 | + static int getNumberOfResults() { return 1; }
| 48 | + static std::vector<int> getTypeMap() { return {30}; } // Same result element type as operand
| 49 | + // Predicates classifying this op's attribute configuration as one of the
| 50 | + // known formats; defined in C++. ignoreAxis presumably skips the axis
| 51 | + // attribute when matching — confirm in the implementation.
| 52 | + [[nodiscard]] bool isBFP16(bool ignoreAxis = false);
| 53 | + [[nodiscard]] bool isMX4(bool ignoreAxis = false);
| 54 | + [[nodiscard]] bool isMX6(bool ignoreAxis = false);
| 55 | + [[nodiscard]] bool isMX9(bool ignoreAxis = false);
| 56 | + }];
| 57 | +
| 58 | + let extraClassDefinition = [{
| 59 | + // Factory for this op's shape helper; caller takes ownership of the result.
| 60 | + onnx_mlir::ONNXOpShapeHelper * AMDQuarkBFPQuantizeDequantizeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
| 61 | + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
| 62 | + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::AMDQuarkBFPQuantizeDequantizeOpShapeHelper(op, oper, ieb, scope);
| 63 | + assert(sh && "failed to allocate shape helper"); // NOTE(review): operator new throws std::bad_alloc on failure, so this assert can never fire
| 64 | + return sh;
| 65 | + }
| 66 | + }];
| 67 | +}
| 64 | + |
0 commit comments