
Commit 242c602

Deprecation updates
1 parent 6172770 commit 242c602

6 files changed: +18 −433 lines

bitsandbytes/autograd/_functions.py

Lines changed: 10 additions & 20 deletions
@@ -49,6 +49,10 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)


+@deprecated(
+    "This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
+    category=FutureWarning,
+)
 def get_inverse_transform_indices(
     transform_tile: Callable[[torch.Tensor], torch.Tensor],
     tile_size: Tuple[int, int],
@@ -80,6 +84,10 @@ def get_inverse_transform_indices(
     return permuted_tile_indices


+@deprecated(
+    "This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
+    category=FutureWarning,
+)
 def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
     """
     Undo a tiled permutation such as turing or ampere layout
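Note: the decorator applied here matches the PEP 702 `deprecated` decorator (importable from `typing_extensions`), which emits the given warning category whenever the decorated callable is invoked. A minimal sketch of the pattern, assuming that import; the repository's actual import site is not shown in this diff:

    from typing_extensions import deprecated  # PEP 702 backport

    @deprecated(
        "This function is deprecated and will be removed in a future release.",
        category=FutureWarning,
    )
    def legacy_helper(x):
        return x

    legacy_helper(1)  # warns: FutureWarning with the message above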
@@ -225,25 +233,9 @@ def supports_igemmlt(device: torch.device) -> bool:
     return True


-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def _get_tile_size(format):
-    assert format in (
-        "col_turing",
-        "col_ampere",
-    ), f"please find this assert and manually enter tile size for {format}"
-    return (8, 32) if format == "col_turing" else (32, 32)
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_tile_inds(format, device):
-    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
-    with torch.no_grad():
-        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
-
-
 @dataclass
 class MatmulLtState:
-    _tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None  # TODO: remove

     force_no_igemmlt: bool = False

@@ -279,9 +271,7 @@ def reset_grads(self):

     @property
     def tile_indices(self):
-        if self._tile_indices is None:
-            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
-        return self._tile_indices
+        raise ValueError("tile_indices is no longer supported.")


 class MatMul8bitLt(torch.autograd.Function):
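Note: the new deprecation messages steer callers toward `int8_vectorwise_dequant`. A hedged sketch of the row-major int8 round trip it belongs to, assuming the `bitsandbytes.functional` API of this release (CUDA device required):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4, 8, dtype=torch.float16, device="cuda")
    # Row-major int8 quantization: quantized tensor, per-row absmax stats,
    # and outlier columns (None when threshold == 0).
    A_i8, row_stats, _ = F.int8_vectorwise_quant(A)
    A_fp = F.int8_vectorwise_dequant(A_i8, row_stats)  # rescales rows by stats / 127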

bitsandbytes/functional.py

Lines changed: 6 additions & 189 deletions
@@ -182,13 +182,6 @@ def get_instance(cls):
         return cls._instance


-dtype2bytes = {}
-dtype2bytes[torch.float32] = 4
-dtype2bytes[torch.float16] = 2
-dtype2bytes[torch.bfloat16] = 2
-dtype2bytes[torch.uint8] = 1
-dtype2bytes[torch.int8] = 1
-
 FIRST_CUDA_DEVICE = torch.device("cuda", index=0)

 # When multiple GPUs are present, we use a context manager to
@@ -207,7 +200,7 @@ def _cuda_device_of(a: torch.Tensor):


 def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
-    num_bytes = dtype2bytes[dtype] * prod(shape)
+    num_bytes = dtype.itemsize * prod(shape)
     cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
     c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
     new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
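Note: the deleted `dtype2bytes` table is replaced by `torch.dtype.itemsize`, which recent PyTorch versions (2.1+) expose directly:

    import torch

    # itemsize reports the element width in bytes, making the lookup dict redundant.
    assert torch.float32.itemsize == 4
    assert torch.bfloat16.itemsize == 2
    assert torch.int8.itemsize == 1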
@@ -217,15 +210,14 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
     return out


-def prefetch_tensor(A, to_cpu=False):
+def prefetch_tensor(A: torch.Tensor, to_cpu=False):
     assert A.is_paged, "Only paged tensors can be prefetched!"
     if to_cpu:
         deviceid = -1
     else:
         deviceid = A.page_deviceid

-    num_bytes = dtype2bytes[A.dtype] * A.numel()
-    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
+    lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid))


 def elementwise_func(func_name, A, B, value, prefetch=True):
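Likewise, `Tensor.nbytes` already folds the dtype width into the byte count, so the manual `dtype2bytes[A.dtype] * A.numel()` product is unnecessary:

    import torch

    t = torch.empty(16, 32, dtype=torch.float16)
    # nbytes == numel() * element_size()
    assert t.nbytes == t.numel() * t.element_size() == 16 * 32 * 2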
@@ -499,106 +491,6 @@ def post_call(prev_device):
         torch.cuda.set_device(prev_device)


-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def get_transform_func(dtype, orderA, orderOut, transpose=False):
-    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
-    if not hasattr(lib, name):
-        print(name)
-        raise ValueError(
-            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
-        )
-    else:
-        return getattr(lib, name)
-
-
-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
-    # init_func = torch.empty
-    init_func = torch.zeros
-    dims = len(shape)
-
-    if dims == 2:
-        rows = shape[0]
-    elif dims == 3:
-        rows = shape[0] * shape[1]
-    cols = shape[-1]
-
-    state = (shape, to_order)
-    if transpose:
-        # swap dims
-        tmp = rows
-        rows = cols
-        cols = tmp
-        state = (shape[::-1], to_order)
-
-    if to_order == "row" or to_order == "col":
-        return init_func(shape, dtype=dtype, device=device), state
-    elif to_order == "col32":
-        # blocks of 32 columns (padded)
-        cols = 32 * ((cols + 31) // 32)
-        return init_func((rows, cols), dtype=dtype, device=device), state
-    elif to_order == "col_turing":
-        # blocks of 32 columns and 8 rows
-        cols = 32 * ((cols + 31) // 32)
-        rows = 8 * ((rows + 7) // 8)
-        return init_func((rows, cols), dtype=dtype, device=device), state
-    elif to_order == "col_ampere":
-        # blocks of 32 columns and 32 rows
-        cols = 32 * ((cols + 31) // 32)
-        rows = 32 * ((rows + 31) // 32)
-        return init_func((rows, cols), dtype=dtype, device=device), state
-    else:
-        raise NotImplementedError(f"To_order not supported: {to_order}")
-
-
-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def nvidia_transform(
-    A,
-    to_order,
-    from_order="row",
-    out=None,
-    transpose=False,
-    state=None,
-    ld=None,
-):
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
-    else:
-        new_state = (state[1], to_order)
-    func = get_transform_func(A.dtype, from_order, to_order, transpose)
-
-    shape = state[0]
-    if len(shape) == 2:
-        dim1 = ct.c_int32(shape[0])
-        dim2 = ct.c_int32(shape[1])
-    elif ld is not None:
-        n = prod(shape)
-        dim1 = prod([shape[i] for i in ld])
-        dim2 = ct.c_int32(n // dim1)
-        dim1 = ct.c_int32(dim1)
-    else:
-        dim1 = ct.c_int32(shape[0] * shape[1])
-        dim2 = ct.c_int32(shape[2])
-
-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
-
-    return out, new_state
-
-
 def estimate_quantiles(
     A: Tensor,
     out: Optional[torch.Tensor] = None,
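Note: these layout-transform helpers had already been deprecated with `category=FutureWarning` before this removal, and the functions deprecated elsewhere in this commit still only warn. During migration it may help to promote those warnings to errors in a test suite, using the standard library alone:

    import warnings

    # Turn remaining FutureWarnings into hard errors so any lingering
    # call sites of deprecated APIs fail loudly in CI.
    warnings.simplefilter("error", FutureWarning)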
@@ -1715,6 +1607,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
     return current_gnorm, clip_value, gnorm_scale


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
     assert len(histogram.shape) == 2
     assert histogram.dtype == torch.float32
@@ -2105,6 +1998,7 @@ def int8_mm_dequant(
     return result


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def get_colrow_absmax(
     A: torch.Tensor,
     row_stats: Optional[torch.Tensor] = None,
@@ -2162,6 +2056,7 @@ def get_colrow_absmax(
     return row_stats, col_stats, outlier_mask


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def get_row_absmax(A: torch.Tensor, threshold=0.0):
     """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.

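Note: for callers that only need the statistic itself, the row-wise absmax that `get_row_absmax` computes can be expressed in plain PyTorch. A hedged equivalent, ignoring the original's `threshold` handling:

    import torch

    A = torch.randn(4, 8, dtype=torch.float16)
    # Per-row absolute maximum, the LLM.int8() row scaling statistic.
    row_stats = A.abs().amax(dim=-1).float()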
@@ -2366,58 +2261,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
     return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold)


-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
-    prev_device = pre_call(A.device)
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
-    else:
-        new_state = (state[0], to_order)  # (shape, order)
-
-    shape = state[0]
-    if len(shape) == 2:
-        dim1 = ct.c_int32(shape[0])
-        dim2 = ct.c_int32(shape[1])
-    else:
-        dim1 = ct.c_int32(shape[0] * shape[1])
-        dim2 = ct.c_int32(shape[2])
-
-    is_on_gpu([A, out])
-    if to_order == "col32":
-        if transpose:
-            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "col_turing":
-        if transpose:
-            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "col_ampere":
-        if transpose:
-            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "row":
-        if from_order == "col_turing":
-            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
-        elif from_order == "col_ampere":
-            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
-    else:
-        raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
-
-    post_call(prev_device)
-
-    return out, new_state
-
-
 def spmm_coo(
     cooA: Union[COOSparseTensor, torch.Tensor],
     B: torch.Tensor,
@@ -2692,29 +2535,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
         return x.to(dtype)
     else:
         return None
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def extract_outliers(A, SA, idx):
-    shapeA = SA[0]
-    formatA = SA[1]
-    assert formatA in ["col_turing", "col_ampere"]
-    assert A.device.type == "cuda"
-
-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
-
-    idx_size = ct.c_int32(idx.numel())
-    rows = ct.c_int32(shapeA[0])
-    cols = ct.c_int32(shapeA[1])
-    ptrA = get_ptr(A)
-    ptrIdx = get_ptr(idx)
-    ptrOut = get_ptr(out)
-
-    prev_device = pre_call(A.device)
-    if formatA == "col_turing":
-        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
-    elif formatA == "col_ampere":
-        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
-    post_call(prev_device)
-
-    return out
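Note: `extract_outliers` existed to gather outlier columns out of the tiled `col_turing`/`col_ampere` layouts. With row-major storage the same gather is ordinary column indexing; a sketch:

    import torch

    A = torch.randint(-128, 127, (64, 128), dtype=torch.int8)
    idx = torch.tensor([3, 17, 42])
    # In row-major layout, extracting outlier columns is plain advanced indexing.
    outliers = A[:, idx]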

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 3 deletions
@@ -11,7 +11,6 @@
 import torch.nn.functional as F

 import bitsandbytes as bnb
-from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import (
@@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
         weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]

     if weight_format != "row":
-        tile_indices = get_tile_inds(weight_format, weight.device)
-        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+        raise ValueError(f"Only 'row' weight format is supported, got {weight_format}")


 class Embedding8bit(nn.Embedding):
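Note: a practical consequence of this change is that checkpoints serialized with a tiled weight format are rejected at load time instead of being converted. A hypothetical pre-flight check (the state-dict contents here are placeholders):

    # Reject tiled formats up front rather than deep inside load_state_dict.
    state_dict = {"linear.weight_format": "col_turing"}  # hypothetical old checkpoint

    for key, value in state_dict.items():
        if key.endswith("weight_format") and value != "row":
            raise ValueError(f"Only 'row' weight format is supported, got {value}")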
