@@ -1229,25 +1229,25 @@ def Arith_ScalingExtFOp
12291229 let summary = "Upcasts input floats using provided scales values following "
12301230 "OCP MXFP Spec";
12311231 let description = [{
1232- This operation upcasts input floating-point values using provided scale
1233- values. It expects both scales and the input operand to be of the same shape,
1234- making the operation elementwise. Scales are usually calculated per block
1235- following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1236-
1237- If scales are calculated per block where blockSize != 1, then scales may
1238- require broadcasting to make this operation elementwise. For example, let's
1239- say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1240- assuming quantization happens on the last axis, the input can be reshaped to
1241- `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1242- per block on the last axis. Therefore, scales will be of shape
1243- `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1244- shape as long as it is broadcast compatible with the input, e.g.,
1245- `<1 x 1 x ... (dimN/blockSize) x 1>`.
1246-
1247- In this example, before calling into `arith.scaling_extf`, scales must be
1248- broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1249- that there could be multiple quantization axes. Internally,
1250- `arith.scaling_extf` would perform the following:
1232+ This operation upcasts input floating-point values using provided scale
1233+ values. It expects both scales and the input operand to be of the same shape,
1234+ making the operation elementwise. Scales are usually calculated per block
1235+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1236+
1237+ If scales are calculated per block where blockSize != 1, then scales may
1238+ require broadcasting to make this operation elementwise. For example, let's
1239+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1240+ assuming quantization happens on the last axis, the input can be reshaped to
1241+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1242+ per block on the last axis. Therefore, scales will be of shape
1243+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1244+ shape as long as it is broadcast compatible with the input, e.g.,
1245+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
1246+
1247+ In this example, before calling into `arith.scaling_extf`, scales must be
1248+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1249+ that there could be multiple quantization axes. Internally,
1250+ `arith.scaling_extf` would perform the following:
12511251
12521252 ```
12531253 resultTy = get_type(result)
@@ -1260,6 +1260,17 @@ def Arith_ScalingExtFOp
12601260 ```
12611261 It propagates NaN values. Therefore, if either scale or the input element
12621262 contains NaN, then the output element value will also be a NaN.
1263+
1264+ Example:
1265+
1266+ ```mlir
1267+ // Upcast from f4E2M1FN to f32.
1268+ %a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
1269+
1270+ // Broadcasting to perform eltwise upcasting (Block size = 32)
1271+ %f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
1272+ %h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
1273+ ```
12631274 }];
12641275 let hasVerifier = 1;
12651276 let assemblyFormat =
0 commit comments