Skip to content

Commit 6e0a4b3

Browse files
Update docstring
1 parent eed9c3c commit 6e0a4b3

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

bitsandbytes/functional.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2429,7 +2429,32 @@ def get_colrow_absmax(
24292429
nnz_block_ptr: Optional[torch.Tensor] = None,
24302430
threshold=0.0,
24312431
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
2432-
# Note: prior impl only works with fp16
2432+
""" "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
2433+
2434+
The row-wise and column-wise absmax values are determined.
2435+
2436+
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
2437+
2438+
<Tip>
2439+
This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
2440+
The column-wise quantization scales are not typically needed in inference scenarios.
2441+
</Tip>
2442+
2443+
Args:
2444+
A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
2445+
row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
2446+
col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
2447+
nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
2448+
threshold (`float`, `optional`):
2449+
An optional threshold for sparse decomposition of outlier features.
2450+
No outliers are held back when 0.0. Defaults to 0.0.
2451+
2452+
Returns:
2453+
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
2454+
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
2455+
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
2456+
- `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
2457+
"""
24332458
assert A.is_floating_point()
24342459

24352460
outlier_mask = None

0 commit comments

Comments
 (0)