diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 20c9097b51e6d..a38cf41a3e09b 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1229,37 +1229,50 @@ def Arith_ScalingExtFOp let summary = "Upcasts input floats using provided scales values following " "OCP MXFP Spec"; let description = [{ - This operation upcasts input floating-point values using provided scale - values. It expects both scales and the input operand to be of the same shape, - making the operation elementwise. Scales are usually calculated per block - following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. - - If scales are calculated per block where blockSize != 1, then scales may - require broadcasting to make this operation elementwise. For example, let's - say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and - assuming quantization happens on the last axis, the input can be reshaped to - `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated - per block on the last axis. Therefore, scales will be of shape - `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other - shape as long as it is broadcast compatible with the input, e.g., - `<1 x 1 x ... (dimN/blockSize) x 1>`. - - In this example, before calling into `arith.scaling_extf`, scales must be - broadcasted to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Note - that there could be multiple quantization axes. Internally, - `arith.scaling_extf` would perform the following: + This operation upcasts input floating-point values using provided scale + values. It expects both scales and the input operand to be of the same shape, + making the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. 
- ``` - resultTy = get_type(result) - scaleTy = get_type(scale) - inputTy = get_type(input) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy - input.extf = arith.extf(input) : inputTy to resultTy - result = arith.mulf(scale.extf, input.extf) + If scales are calculated per block where blockSize != 1, then scales may + require broadcasting to make this operation elementwise. For example, let's + say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_extf`, scales must be + broadcasted to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. Internally, + `arith.scaling_extf` would perform the following: + + ```mlir + // Cast scale to result type. + %0 = arith.truncf %scale : f32 to f8E8M0FNU + %1 = arith.extf %0 : f8E8M0FNU to f16 + + // Cast input to result type. + %2 = arith.extf %input : f4E2M1FN to f16 + + // Perform scaling. + %3 = arith.mulf %2, %1 : f16 ``` It propagates NaN values. Therefore, if either scale or the input element contains NaN, then the output element value will also be a NaN. + + Example: + + ```mlir + // Upcast from f4E2M1FN to f32. + %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32 + + // Element-wise upcast with broadcast (blockSize = 32). + %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> + %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16> + ``` }]; let hasVerifier = 1; let assemblyFormat = @@ -1397,14 +1410,27 @@ def Arith_ScalingTruncFOp that there could be multiple quantization axes. 
Internally, `arith.scaling_truncf` would perform the following: + ```mlir + // Cast scale to input type. + %0 = arith.truncf %scale : f32 to f8E8M0FNU + %1 = arith.extf %0 : f8E8M0FNU to f16 + + // Perform scaling. + %2 = arith.divf %input, %1 : f16 + + // Cast to result type. + %3 = arith.truncf %2 : f16 to f4E2M1FN ``` - scaleTy = get_type(scale) - inputTy = get_type(input) - resultTy = get_type(result) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy - result = arith.divf(input, scale.extf) - result.cast = arith.truncf(result, resultTy) + + Example: + + ```mlir + // Downcast from f32 to f4E2M1FN. + %a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN + + // Element-wise downcast with broadcast (blockSize = 32). + %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> + %h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN> ``` }]; let hasVerifier = 1;