Commit 3a45ff5

Update documentation for _cast_fp4
Signed-off-by: ajrasane <[email protected]>
1 parent c06fcac commit 3a45ff5

1 file changed: +5 -1 lines changed

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 5 additions & 1 deletion
@@ -612,7 +612,11 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
 
 
 def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2
+    as two FP4 values are packed into a single byte.
+    """
     array_f32_t = torch.from_numpy(array)
     array_f32_t_shape = array_f32_t.shape
     assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
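
The packing constraint described in the new docstring can be illustrated with a small NumPy sketch. This is not the code from qdq_utils.py: `pack_fp4_codes` is a hypothetical helper, it takes already-quantized 4-bit codes rather than float32 values, and the nibble order (earlier element in the low nibble, pairs taken in flattened order) is an assumption made for illustration only.

```python
import numpy as np


def pack_fp4_codes(codes: np.ndarray) -> np.ndarray:
    """Pack 4-bit codes (one per uint8 element) two to a byte.

    Hypothetical helper, not part of modelopt. Assumes pairs are taken in
    flattened order and the earlier element goes in the low nibble.
    """
    if codes.shape[0] % 2 != 0:
        raise ValueError("first dimension must be divisible by 2")
    # Keep only the low 4 bits of each element.
    flat = codes.astype(np.uint8).ravel() & 0x0F
    # Combine each consecutive pair into one byte.
    packed = flat[0::2] | (flat[1::2] << 4)
    # Two values per byte, so the first dimension is halved.
    return packed.reshape(codes.shape[0] // 2, *codes.shape[1:])


codes = np.array([[0x1, 0x2], [0x3, 0x4]], dtype=np.uint8)
packed = pack_fp4_codes(codes)
```

Under these assumptions the 2x2 input becomes a single row of two bytes, 0x21 and 0x43, which shows why an odd first dimension would leave a value without a partner to share its byte.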
