Skip to content

Commit f61d8bc

Browse files
update docstrings
1 parent 4bced86 commit f61d8bc

File tree

1 file changed

+112
-3
lines changed

1 file changed

+112
-3
lines changed

bitsandbytes/functional.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,20 @@ def get_special_format_str():
437437

438438

439439
def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
440+
"""Verifies that the input tensors are all on the same device.
441+
442+
An input tensor may also be marked as `paged`, in which case the device placement is ignored.
443+
444+
Args:
445+
tensors (Iterable[Optional[torch.Tensor]]): A list of tensors to verify.
446+
447+
Raises:
448+
`RuntimeError`: Raised when the verification fails.
449+
450+
Returns:
451+
`Literal[True]`
452+
"""
453+
440454
on_gpu = True
441455
gpu_ids = set()
442456

@@ -1199,7 +1213,7 @@ def quantize_4bit(
11991213

12001214
with _cuda_device_of(A):
12011215
args = (
1202-
get_ptr(None),
1216+
None,
12031217
get_ptr(A),
12041218
get_ptr(absmax),
12051219
get_ptr(out),
@@ -1346,7 +1360,7 @@ def dequantize_4bit(
13461360

13471361
with _cuda_device_of(A):
13481362
args = (
1349-
get_ptr(None),
1363+
None,
13501364
get_ptr(A),
13511365
get_ptr(absmax),
13521366
get_ptr(out),
@@ -2255,6 +2269,25 @@ def igemmlt(
22552269

22562270

22572271
def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
2272+
"""Performs an 8-bit integer matrix multiplication.
2273+
2274+
A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is
2275+
utilized to accelerate the operation.
2276+
2277+
Args:
2278+
A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`.
2279+
B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`.
2280+
out (`torch.Tensor`, *optional*): A pre-allocated tensor used to store the result.
2281+
dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`.
2282+
2283+
Raises:
2284+
`NotImplementedError`: The operation is not supported in the current environment.
2285+
`RuntimeError`: Raised when the operation cannot be completed for any other reason.
2286+
2287+
Returns:
2288+
`torch.Tensor`: The result of the operation.
2289+
"""
2290+
22582291
#
22592292
# To use the IMMA tensor core kernels without special Turing/Ampere layouts,
22602293
# cublasLt has some rules, namely: A must be transposed, B must not be transposed.
@@ -2336,6 +2369,19 @@ def int8_mm_dequant(
23362369
out: Optional[torch.Tensor] = None,
23372370
bias: Optional[torch.Tensor] = None,
23382371
):
2372+
"""Performs dequantization on the result of a quantized int8 matrix multiplication.
2373+
2374+
Args:
2375+
A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication.
2376+
row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication.
2377+
col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication.
2378+
out (`torch.Tensor`, *optional*): A pre-allocated tensor to store the output of the operation.
2379+
bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result.
2380+
2381+
Returns:
2382+
`torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
2383+
"""
2384+
23392385
assert A.dtype == torch.int32
23402386

23412387
if bias is not None:
@@ -2409,6 +2455,20 @@ def get_colrow_absmax(
24092455

24102456

24112457
def get_row_absmax(A: torch.Tensor, threshold=0.0):
2458+
"""Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm.
2459+
2460+
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
2461+
2462+
Args:
2463+
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
2464+
threshold (`float`, *optional*):
2465+
An optional threshold for sparse decomposition of outlier features.
2466+
No outliers are held back when 0.0. Defaults to 0.0.
2467+
2468+
Returns:
2469+
`torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
2470+
"""
2471+
24122472
assert A.dtype == torch.float16
24132473

24142474
rows = prod(A.shape[:-1])
@@ -2520,6 +2580,37 @@ def double_quant(
25202580
out_row: Optional[torch.Tensor] = None,
25212581
threshold=0.0,
25222582
):
2583+
"""Determine the quantization statistics for input matrix `A` in accordance with the `LLM.int8()` algorithm.
2584+
2585+
The statistics are determined both row-wise and column-wise (transposed).
2586+
2587+
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
2588+
2589+
<Tip>
2590+
This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead.
2591+
This implementation performs additional column-wise transposed calculations which are not optimized.
2592+
</Tip>
2593+
2594+
Args:
2595+
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
2596+
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
2597+
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
2598+
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
2599+
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
2600+
threshold (`float`, *optional*):
2601+
An optional threshold for sparse decomposition of outlier features.
2602+
2603+
No outliers are held back when 0.0. Defaults to 0.0.
2604+
2605+
Returns:
2606+
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
2607+
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
2608+
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
2609+
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
2610+
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
2611+
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
2612+
"""
2613+
25232614
# TODO: Optimize/write CUDA kernel for this?
25242615
# Note: for inference, use the new int8_vectorwise_quant.
25252616

@@ -2541,6 +2632,24 @@ def double_quant(
25412632

25422633

25432634
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
2635+
"""Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance with the `LLM.int8()` algorithm.
2636+
2637+
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
2638+
2639+
Args:
2640+
A (`torch.Tensor` with dtype `torch.float16`): The input tensor.
2641+
threshold (`float`, *optional*):
2642+
An optional threshold for sparse decomposition of outlier features.
2643+
2644+
No outliers are held back when 0.0. Defaults to 0.0.
2645+
2646+
Returns:
2647+
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
2648+
- `torch.Tensor` with dtype `torch.int8`: The quantized data.
2649+
- `torch.Tensor` with dtype `torch.float32`: The quantization scales.
2650+
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
2651+
"""
2652+
25442653
assert A.dtype == torch.half
25452654
is_on_gpu([A])
25462655

@@ -2838,7 +2947,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
28382947

28392948

28402949
@deprecated(
2841-
"This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.",
2950+
"This function is deprecated and will be removed in a future release. Consider using `int8_mm_dequant` instead.",
28422951
category=FutureWarning,
28432952
)
28442953
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):

0 commit comments

Comments
 (0)