@@ -2429,7 +2429,32 @@ def get_colrow_absmax(
     nnz_block_ptr: Optional[torch.Tensor] = None,
     threshold=0.0,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    # Note: prior impl only works with fp16
+    """Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm.
+
+    The row-wise and column-wise absmax values are determined.
+
+    For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
+
+    <Tip>
+    This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
+    The column-wise quantization scales are not typically needed in inference scenarios.
+    </Tip>
+
+    Args:
+        A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
+        row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
+        col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
+        nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
+        threshold (`float`, *optional*):
+            An optional threshold for sparse decomposition of outlier features.
+            No outliers are held back when 0.0. Defaults to 0.0.
+
+    Returns:
+        `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
+        - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
+        - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
+        - `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
+    """
     assert A.is_floating_point()
 
     outlier_mask = None
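
For readers of this docstring, the statistics it describes can be illustrated with a small, dependency-free sketch. This is not the library's implementation: `colrow_absmax_sketch` is a hypothetical name, it works on plain nested lists rather than a `torch.Tensor`, and it assumes that when `threshold > 0.0`, entries whose magnitude is at or above the threshold are treated as outliers and excluded (zeroed) before the maxima are taken. It also assumes a non-empty, rectangular input.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]
Mask = List[List[bool]]


def colrow_absmax_sketch(
    A: Matrix, threshold: float = 0.0
) -> Tuple[List[float], List[float], Optional[Mask]]:
    """Hypothetical reference: row-wise and column-wise absmax of a 2D matrix."""
    # Work on absolute values; the statistics are maxima of |A|.
    abs_a = [[abs(x) for x in row] for row in A]

    outlier_mask: Optional[Mask] = None
    if threshold > 0.0:
        # Assumption: |x| >= threshold marks an outlier, and outliers
        # are zeroed so they do not influence the statistics.
        outlier_mask = [[v >= threshold for v in row] for row in abs_a]
        abs_a = [
            [0.0 if is_outlier else v for v, is_outlier in zip(row, mask_row)]
            for row, mask_row in zip(abs_a, outlier_mask)
        ]

    # One maximum per row, and one per column (zip(*...) transposes).
    row_stats = [max(row) for row in abs_a]
    col_stats = [max(col) for col in zip(*abs_a)]
    return row_stats, col_stats, outlier_mask


if __name__ == "__main__":
    A = [[1.0, -8.0, 0.5], [-2.0, 3.0, -7.0]]
    print(colrow_absmax_sketch(A))                 # no outlier handling
    print(colrow_absmax_sketch(A, threshold=6.0))  # outliers masked out
```

With `threshold=0.0` the mask is `None` and every entry contributes to the statistics; with a positive threshold the returned mask marks the entries that were held back.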