We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 161db9f · commit a7730c0 · Copy full SHA for a7730c0
modelopt/onnx/quantization/qdq_utils.py
@@ -612,7 +612,11 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
612
613
614
def _cast_fp4(array: np.ndarray) -> np.ndarray:
615
- """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+ """Cast a numpy array to FLOAT4E2M1 using PyTorch.
616
+
617
+ Note: The first dimension of the array must be divisible by 2
618
+ as two FP4 values are packed into a single byte.
619
+ """
620
array_f32_t = torch.from_numpy(array)
621
array_f32_t_shape = array_f32_t.shape
622
assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
0 commit comments