|
1 | | -from collections.abc import Callable |
2 | 1 | from dataclasses import dataclass |
3 | 2 | from math import prod |
4 | 3 | from typing import Optional |
5 | 4 | import warnings |
6 | 5 | from warnings import warn |
7 | 6 |
|
8 | 7 | import torch |
9 | | -from typing_extensions import deprecated |
10 | 8 |
|
11 | 9 | import bitsandbytes.functional as F |
12 | 10 |
|
@@ -50,66 +48,9 @@ def get_current_outlier_idx(self): |
50 | 48 | return torch.Tensor(list(self.outliers)).to(torch.int64) |
51 | 49 |
|
52 | 50 |
|
53 | | -@deprecated( |
54 | | - "This function is deprecated and will be removed in a future release.", |
55 | | - category=FutureWarning, |
56 | | -) |
57 | | -def get_inverse_transform_indices( |
58 | | - transform_tile: Callable[[torch.Tensor], torch.Tensor], |
59 | | - tile_size: tuple[int, int], |
60 | | -): |
61 | | - """ |
62 | | - Compute a permutation of indices that invert the specified (tiled) matrix transformation |
63 | | -
|
64 | | - :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2] |
65 | | - :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere |
66 | | - :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size |
67 | | - :example: transform_tile function for the turing layout (bitsandbytes.functional as F) |
68 | | - :returns: indices |
69 | | - """ |
70 | | - d1, d2 = tile_size |
71 | | - assert 0 < d1 * d2 < 2**64 |
72 | | - tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2) |
73 | | - # encode each position in tile as a tuple of <= 8 unique bytes |
74 | | - permuted_tile_indices = torch.zeros_like(tile_indices) |
75 | | - for i in range(8): |
76 | | - # select i-th byte, apply transformation and trace where each index ended up |
77 | | - ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256 |
78 | | - sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous() |
79 | | - assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow" |
80 | | - permuted_tile_i = transform_tile(sample_tile_i) |
81 | | - ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128 |
82 | | - permuted_tile_indices += ith_permuted_indices * (256**i) |
83 | | - if d1 * d2 < 256**i: |
84 | | - break # if all indices fit in i bytes, stop early |
85 | | - return permuted_tile_indices |
86 | | - |
87 | | - |
88 | 51 | _is_compiling = torch.compiler.is_compiling |
89 | 52 |
|
90 | 53 |
|
91 | | -@deprecated( |
92 | | - "This function is deprecated and will be removed in a future release.", |
93 | | - category=FutureWarning, |
94 | | -) |
95 | | -def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: |
96 | | - """ |
97 | | - Undo a tiled permutation such as turing or ampere layout |
98 | | -
|
99 | | - :param permuted_tensor: torch tensor in a permuted layout |
100 | | - :param tile_indices: reverse transformation indices, from get_inverse_transform_indices |
101 | | - :return: contiguous row-major tensor |
102 | | - """ |
103 | | - (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape |
104 | | - assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" |
105 | | - tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() |
106 | | - outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda |
107 | | - outputs[tile_indices.flatten()] = tensor |
108 | | - outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows) |
109 | | - outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols) |
110 | | - return outputs.reshape(rows, cols).contiguous() |
111 | | - |
112 | | - |
113 | 54 | @dataclass |
114 | 55 | class MatmulLtState: |
115 | 56 | _tile_indices: Optional[torch.Tensor] = None # TODO: remove |
@@ -433,7 +374,7 @@ def matmul_4bit( |
433 | 374 | bias: Optional[torch.Tensor] = None, |
434 | 375 | ): |
435 | 376 | assert quant_state is not None |
436 | | - # Change dtype to bfloat16 on CPU |
| 377 | + # Change dtype to input dtype on CPU |
437 | 378 | if A.device.type == "cpu": |
438 | 379 | quant_state.dtype = A.dtype |
439 | 380 |
|
|
0 commit comments