Commit 25e2613

Adding example to scaling_extf description
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 79f1484 commit 25e2613

1 file changed: +30 -19 lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 30 additions & 19 deletions
@@ -1229,25 +1229,25 @@ def Arith_ScalingExtFOp
 let summary = "Upcasts input floats using provided scales values following "
 "OCP MXFP Spec";
 let description = [{
-This operation upcasts input floating-point values using provided scale
-values. It expects both scales and the input operand to be of the same shape,
-making the operation elementwise. Scales are usually calculated per block
-following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
-
-If scales are calculated per block where blockSize != 1, then scales may
-require broadcasting to make this operation elementwise. For example, let's
-say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
-assuming quantization happens on the last axis, the input can be reshaped to
-`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
-per block on the last axis. Therefore, scales will be of shape
-`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
-shape as long as it is broadcast compatible with the input, e.g.,
-`<1 x 1 x ... (dimN/blockSize) x 1>`.
-
-In this example, before calling into `arith.scaling_extf`, scales must be
-broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
-that there could be multiple quantization axes. Internally,
-`arith.scaling_extf` would perform the following:
+This operation upcasts input floating-point values using provided scale
+values. It expects both scales and the input operand to be of the same shape,
+making the operation elementwise. Scales are usually calculated per block
+following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+If scales are calculated per block where blockSize != 1, then scales may
+require broadcasting to make this operation elementwise. For example, let's
+say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+assuming quantization happens on the last axis, the input can be reshaped to
+`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+per block on the last axis. Therefore, scales will be of shape
+`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+shape as long as it is broadcast compatible with the input, e.g.,
+`<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+In this example, before calling into `arith.scaling_extf`, scales must be
+broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+that there could be multiple quantization axes. Internally,
+`arith.scaling_extf` would perform the following:
 
 ```
 resultTy = get_type(result)
@@ -1260,6 +1260,17 @@ def Arith_ScalingExtFOp
 ```
 It propagates NaN values. Therefore, if either scale or the input element
 contains NaN, then the output element value will also be a NaN.
+
+Example:
+
+```mlir
+// Upcast from f4E2M1FN to f32.
+%a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32
+
+// Broadcasting to perform eltwise upcasting (Block size = 32)
+%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
+%h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
+```
 }];
 let hasVerifier = 1;
 let assemblyFormat =
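
The broadcasting requirement in the description can be illustrated outside MLIR. The following is a minimal Python sketch, not the op's implementation: it models values and scales as plain Python floats on a flattened last axis, and it assumes the op ultimately multiplies each extended input element by its extended scale. The helper names `broadcast_block_scales` and `scaling_extf_sketch` are hypothetical and exist only for this illustration.

```python
import math


def broadcast_block_scales(block_scales, block_size):
    """Repeat each per-block scale block_size times along the last axis.

    Models expanding scales of shape <(dimN/blockSize) x 1> to
    <(dimN/blockSize) x blockSize>, flattened to one dimension.
    """
    return [s for s in block_scales for _ in range(block_size)]


def scaling_extf_sketch(values, scales):
    """Elementwise model of the upcast (assumption: result = value * scale).

    Both operands must already have the same shape; NaN in either operand
    propagates through the multiplication.
    """
    if len(values) != len(scales):
        raise ValueError("scales must be broadcast to the input shape first")
    return [v * s for v, s in zip(values, scales)]


# Eight input elements split into two blocks of four (block size 4 is chosen
# only to keep the example short; the commit's MLIR example uses 32).
values = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -6.0]
block_scales = [2.0, 0.25]  # one power-of-two scale per block

scales = broadcast_block_scales(block_scales, 4)
result = scaling_extf_sketch(values, scales)
print(result)  # [1.0, 2.0, 3.0, 4.0, 0.75, 1.0, 1.5, -1.5]
```

Passing the unexpanded `block_scales` directly to `scaling_extf_sketch` raises a `ValueError`, which mirrors the requirement that scales and input have the same shape before the op is applied.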
