"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
2459
+
2460
+
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
2461
+
2462
+
Args:
2463
+
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
2464
+
threshold (`float`, *optional*):
2465
+
An optional threshold for sparse decomposition of outlier features.
2466
+
No outliers are held back when 0.0. Defaults to 0.0.
2467
+
2468
+
Returns:
2469
+
`torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
2470
+
"""
2471
+
2412
2472
assertA.dtype==torch.float16
2413
2473
2414
2474
rows=prod(A.shape[:-1])
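The row-wise statistic documented above can be illustrated with a minimal pure-Python sketch (the function name and the list-based interface are hypothetical stand-ins for illustration; the real code operates on `torch.float16` tensors):

```python
def row_absmax_ignoring_outliers(A, threshold=0.0):
    """Per-row absolute maximum. When threshold > 0, entries with
    |x| >= threshold are treated as outliers and excluded."""
    stats = []
    for row in A:
        if threshold > 0.0:
            kept = [abs(x) for x in row if abs(x) < threshold]
        else:
            kept = [abs(x) for x in row]
        # If every entry in the row is an outlier, the statistic is 0.0.
        stats.append(max(kept) if kept else 0.0)
    return stats

A = [[0.5, -2.0, 100.0],   # 100.0 is an outlier at threshold 6.0
     [1.5, -0.25, 0.75]]
print(row_absmax_ignoring_outliers(A, threshold=6.0))  # [2.0, 1.5]
```

With `threshold=0.0` the outlier filter is disabled and the first row's statistic would be 100.0, which is why sparse decomposition matters: a single outlier would otherwise dominate the quantization scale for the whole row.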
@@ -2520,6 +2580,37 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
"""Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm.

The statistics are determined both row-wise and column-wise (transposed).

For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).

<Tip>
This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead.
This implementation performs additional column-wise transposed calculations which are not optimized.
</Tip>

Args:
    A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
    col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
    row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
    out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
    out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
    threshold (`float`, *optional*):
        An optional threshold for sparse decomposition of outlier features.

        No outliers are held back when 0.0. Defaults to 0.0.

Returns:
    `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.

    - `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
    - `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
    - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
    - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
    - `torch.Tensor` with dtype `torch.int32`, *optional*: A tensor of column indices which contain outlier features.
"""
# TODO: Optimize/write CUDA kernel for this?
# Note: for inference, use the new int8_vectorwise_quant.
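To make the return values concrete, here is a minimal pure-Python sketch of just the row-wise half of the scheme described in the docstring (symmetric absmax quantization to int8, plus outlier-column tracking). The helper name and list-based interface are hypothetical; the real function also computes the column-wise (transposed) results and works on `torch` tensors:

```python
def int8_rowwise_quant(A, threshold=0.0):
    """Scale each row by its absmax (outliers excluded), round values
    into [-127, 127], and collect columns containing any outlier."""
    out_rows, row_stats, outlier_cols = [], [], set()
    for row in A:
        kept = [abs(x) for x in row if threshold == 0.0 or abs(x) < threshold]
        absmax = max(kept) if kept else 1.0
        row_stats.append(absmax)
        q = []
        for j, x in enumerate(row):
            if threshold > 0.0 and abs(x) >= threshold:
                # Outlier entries are zeroed in the int8 part; in LLM.int8()
                # their columns are instead handled separately in fp16.
                outlier_cols.add(j)
                q.append(0)
            else:
                q.append(round(x / absmax * 127))
        out_rows.append(q)
    return out_rows, row_stats, sorted(outlier_cols)

A = [[0.5, -2.0, 100.0],
     [1.0,  0.5, 0.25]]
quantized, stats, cols = int8_rowwise_quant(A, threshold=6.0)
```

Here the first row is scaled by its non-outlier absmax of 2.0 rather than 100.0, and column 2 is reported as an outlier column so the caller can reconstruct it in higher precision.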