
Commit 76fb84a

Additional deprecations/removals.
1 parent: feb1139
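From the visible diff, this commit deletes helpers that were already marked deprecated (`arange`, `pre_call`, `post_call`, `vectorwise_dequant`), newly marks `create_quantile_map` and `estimate_quantiles` as deprecated, and replaces the manual `pre_call`/`post_call` device save/restore pattern with a `_cuda_device_of(...)` context manager throughout `bitsandbytes/functional.py`.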

File tree: 3 files changed (+134, -385 lines)


bitsandbytes/functional.py

Lines changed: 134 additions & 163 deletions
@@ -251,12 +251,6 @@ def fill(A, value, device=None, prefetch=True):
     elementwise_func("fill", A, None, value)
 
 
-@deprecated("Function will be removed in a future release.", category=FutureWarning)
-def arange(A, device=None):
-    elementwise_func("arange", A, None, 0)
-
-
-@deprecated("Function will be removed in a future release.", category=FutureWarning)
 def _mul(A, B, device=None):
     elementwise_func("_mul", A, B, 0)
 
@@ -407,6 +401,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     return torch.tensor(data, dtype=torch.float32)
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def create_quantile_map(A, total_bits=8):
     q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
     q = q.tolist()
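The `@deprecated` decorator used above is imported outside this hunk. A minimal sketch of the same pattern, assuming `typing_extensions.deprecated` (available since typing_extensions 4.5; Python 3.13 ships it as `warnings.deprecated`):

```python
import warnings

# Assumption: the decorator comes from typing_extensions; the actual
# import used by bitsandbytes is not shown in this diff.
from typing_extensions import deprecated


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def old_helper():
    return 42


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_helper()  # emits the FutureWarning at call time

assert any(issubclass(w.category, FutureWarning) for w in caught)
```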
@@ -480,17 +475,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
 
 
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def pre_call(device):
-    prev_device = torch.cuda.current_device()
-    torch.cuda.set_device(device)
-    return prev_device
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def post_call(prev_device):
-    torch.cuda.set_device(prev_device)
-
-
 def estimate_quantiles(
     A: Tensor,
     out: Optional[torch.Tensor] = None,
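The removed `pre_call`/`post_call` pair was a manual save/set/restore of the active CUDA device; note that the surviving `@deprecated` line now decorates `estimate_quantiles`. For reference, a sketch of the equivalent behavior using the stock `torch.cuda.device` context manager (assuming a CUDA-enabled PyTorch build):

```python
import torch

# What pre_call/post_call spelled out by hand:
prev_device = torch.cuda.current_device()  # pre_call: remember the active device
torch.cuda.set_device(0)                   # pre_call: switch to the target device
# ... launch kernels on device 0 ...
torch.cuda.set_device(prev_device)         # post_call: restore the previous device

# The same save/set/restore as a context manager:
with torch.cuda.device(0):
    pass  # kernels launched here see device 0 as current
```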
@@ -539,15 +523,16 @@ def estimate_quantiles(
 
     if out is None:
         out = torch.zeros((256,), dtype=torch.float32, device=A.device)
-    is_on_gpu([A, out])
-    device = pre_call(A.device)
-    if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
-    elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
-    else:
-        raise NotImplementedError(f"Not supported data type {A.dtype}")
-    post_call(device)
+
+    with _cuda_device_of(A):
+        is_on_gpu([A, out])
+
+        if A.dtype == torch.float32:
+            lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+        elif A.dtype == torch.float16:
+            lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
+        else:
+            raise NotImplementedError(f"Not supported data type {A.dtype}")
 
     if num_quantiles < 256:
         step = round(256 / num_quantiles)
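`_cuda_device_of` itself is defined outside this diff. A hypothetical reimplementation, assuming it merely makes the tensor's CUDA device current for the duration of the block (a sketch, not the library's actual code):

```python
from contextlib import contextmanager

import torch


@contextmanager
def _cuda_device_of_sketch(tensor: torch.Tensor):
    # Hypothetical stand-in: activate the tensor's CUDA device inside the
    # block; torch.cuda.device restores the previous device on exit.
    if tensor.device.type == "cuda":
        with torch.cuda.device(tensor.device):
            yield
    else:
        yield  # CPU or other devices: nothing to switch
```

PyTorch ships a comparable guard as `torch.cuda.device_of(tensor)`; whatever extra handling the real `_cuda_device_of` does is not visible here.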
@@ -1219,12 +1204,12 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
     torch.Tensor:
         Quantized 8-bit tensor.
     """
-    prev_device = pre_call(A.device)
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.uint8)
-    is_on_gpu([A, out])
-    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
-    post_call(prev_device)
+    with _cuda_device_of(A):
+        if out is None:
+            out = torch.zeros_like(A, dtype=torch.uint8)
+        is_on_gpu([A, out])
+        lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+
     return out
 
 
@@ -1250,13 +1235,13 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
     torch.Tensor:
         32-bit output tensor.
     """
-    prev_device = pre_call(A.device)
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.float32)
-    is_on_gpu([code, A, out])
-    stream = _get_tensor_stream(A)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
-    post_call(prev_device)
+    with _cuda_device_of(A):
+        if out is None:
+            out = torch.zeros_like(A, dtype=torch.float32)
+        is_on_gpu([code, A, out])
+        stream = _get_tensor_stream(A)
+        lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
+
     return out
 
 
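The two refactored functions above form a round trip: `quantize_no_absmax` maps each float to a uint8 index into the 256-entry `code` table, and `dequantize_no_absmax` looks the values back up. A usage sketch, assuming a CUDA device (both routines call into the CUDA library) and a codebook from `create_dynamic_map`:

```python
import torch

import bitsandbytes.functional as F

A = torch.randn(1024, device="cuda").clamp(-1, 1)
code = F.create_dynamic_map().to(A.device)  # 256-entry float32 codebook

q = F.quantize_no_absmax(A, code)        # uint8 tensor of codebook indices
A_hat = F.dequantize_no_absmax(q, code)  # float32 reconstruction

max_err = (A - A_hat).abs().max()  # small but nonzero: 8-bit quantization error
```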
@@ -1444,61 +1429,60 @@ def optimizer_update_8bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())
 
-    prev_device = pre_call(g.device)
-    is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
-    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        str2optimizer8bit[optimizer_name][0](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(max1),
-            get_ptr(max2),
-            get_ptr(new_max1),
-            get_ptr(new_max2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_int32(g.numel()),
-        )
-    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        str2optimizer8bit[optimizer_name][1](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(max1),
-            get_ptr(max2),
-            get_ptr(new_max1),
-            get_ptr(new_max2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_int32(g.numel()),
-        )
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
-    post_call(prev_device)
+    with _cuda_device_of(g):
+        is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
+        if g.dtype == torch.float32 and state1.dtype == torch.uint8:
+            str2optimizer8bit[optimizer_name][0](
+                get_ptr(p),
+                get_ptr(g),
+                get_ptr(state1),
+                get_ptr(state2),
+                get_ptr(unorm_vec),
+                ct.c_float(max_unorm),
+                ct.c_float(param_norm),
+                ct.c_float(beta1),
+                ct.c_float(beta2),
+                ct.c_float(eps),
+                ct.c_int32(step),
+                ct.c_float(lr),
+                get_ptr(qmap1),
+                get_ptr(qmap2),
+                get_ptr(max1),
+                get_ptr(max2),
+                get_ptr(new_max1),
+                get_ptr(new_max2),
+                ct.c_float(weight_decay),
+                ct.c_float(gnorm_scale),
+                ct.c_int32(g.numel()),
+            )
+        elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
+            str2optimizer8bit[optimizer_name][1](
+                get_ptr(p),
+                get_ptr(g),
+                get_ptr(state1),
+                get_ptr(state2),
+                get_ptr(unorm_vec),
+                ct.c_float(max_unorm),
+                ct.c_float(param_norm),
+                ct.c_float(beta1),
+                ct.c_float(beta2),
+                ct.c_float(eps),
+                ct.c_int32(step),
+                ct.c_float(lr),
+                get_ptr(qmap1),
+                get_ptr(qmap2),
+                get_ptr(max1),
+                get_ptr(max2),
+                get_ptr(new_max1),
+                get_ptr(new_max2),
+                ct.c_float(weight_decay),
+                ct.c_float(gnorm_scale),
+                ct.c_int32(g.numel()),
+            )
+        else:
+            raise ValueError(
+                f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+            )
 
 
 def optimizer_update_8bit_blockwise(
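`str2optimizer8bit` is a dispatch table keyed by optimizer name; the code above indexes `[0]` for float32 gradients and `[1]` for float16. A minimal sketch of that pattern (the kernel functions below are hypothetical stand-ins, not the library's real ctypes bindings):

```python
import torch

# Hypothetical stand-ins for the real lib.c* entry points:
def c_optimizer8bit_fp32(*args): ...
def c_optimizer8bit_fp16(*args): ...

# Shape of the table: name -> (fp32 kernel, fp16 kernel).
str2optimizer8bit_sketch = {
    "adam": (c_optimizer8bit_fp32, c_optimizer8bit_fp16),
}

def pick_kernel(optimizer_name: str, grad_dtype: torch.dtype):
    kernels = str2optimizer8bit_sketch[optimizer_name]
    if grad_dtype == torch.float32:
        return kernels[0]
    if grad_dtype == torch.float16:
        return kernels[1]
    raise ValueError(f"Unsupported gradient dtype: {grad_dtype}")
```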
@@ -1577,25 +1561,24 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
         The current optimization steps (number of past gradient norms).
 
     """
-    prev_device = pre_call(grad.device)
-    is_on_gpu([grad, gnorm_vec])
-    if grad.dtype == torch.float32:
-        lib.cpercentile_clipping_g32(
-            get_ptr(grad),
-            get_ptr(gnorm_vec),
-            ct.c_int32(step),
-            ct.c_int32(grad.numel()),
-        )
-    elif grad.dtype == torch.float16:
-        lib.cpercentile_clipping_g16(
-            get_ptr(grad),
-            get_ptr(gnorm_vec),
-            ct.c_int32(step),
-            ct.c_int32(grad.numel()),
-        )
-    else:
-        raise ValueError(f"Gradient type {grad.dtype} not supported!")
-    post_call(prev_device)
+    with _cuda_device_of(grad):
+        is_on_gpu([grad, gnorm_vec])
+        if grad.dtype == torch.float32:
+            lib.cpercentile_clipping_g32(
+                get_ptr(grad),
+                get_ptr(gnorm_vec),
+                ct.c_int32(step),
+                ct.c_int32(grad.numel()),
+            )
+        elif grad.dtype == torch.float16:
+            lib.cpercentile_clipping_g16(
+                get_ptr(grad),
+                get_ptr(gnorm_vec),
+                ct.c_int32(step),
+                ct.c_int32(grad.numel()),
+            )
+        else:
+            raise ValueError(f"Gradient type {grad.dtype} not supported!")
 
     current_gnorm = torch.sqrt(gnorm_vec[step % 100])
     vals, idx = torch.sort(gnorm_vec)
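As the surviving context shows, `percentile_clipping` keeps a rolling 100-entry history of gradient norms, indexed by `step % 100`, and sorts it to derive a percentile-based clip threshold. A usage sketch; the three-value return is an assumption, since the function's `return` statement lies below this hunk:

```python
import torch

import bitsandbytes.functional as F

grad = torch.randn(4096, device="cuda", dtype=torch.float32)
gnorm_vec = torch.zeros(100, device="cuda")  # rolling gradient-norm history

for step in range(1, 101):
    # Assumed return: (current_gnorm, clip_value, gnorm_scale).
    current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
        grad, gnorm_vec, step, percentile=5
    )
```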
@@ -2333,7 +2316,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     if out is None:
         out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
     nnz = cooA.nnz
-    prev_device = pre_call(B.device)
+
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
     assert cooA.values.numel() == nnz
@@ -2370,43 +2353,43 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     cldb = ct.c_int32(ldb)
     cldc = ct.c_int32(ldc)
 
-    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
-    if B.dtype == torch.float16:
-        lib.cspmm_coo_very_sparse_naive_fp16(
-            ptrMaxCount,
-            ptrMaxIdx,
-            ptrOffset,
-            ptrRowidx,
-            ptrColidx,
-            ptrValues,
-            ptrB,
-            ptrC,
-            ptrDequantStats,
-            cnnz_rows,
-            cnnz,
-            crowsA,
-            crowsB,
-            ccolsB,
-        )
-    elif B.dtype == torch.int8:
-        lib.cspmm_coo_very_sparse_naive_int8(
-            ptrMaxCount,
-            ptrMaxIdx,
-            ptrOffset,
-            ptrRowidx,
-            ptrColidx,
-            ptrValues,
-            ptrB,
-            ptrC,
-            ptrDequantStats,
-            cnnz_rows,
-            cnnz,
-            crowsA,
-            crowsB,
-            ccolsB,
-        )
-    # else: assertion error
-    post_call(prev_device)
+    with _cuda_device_of(B):
+        is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
+        if B.dtype == torch.float16:
+            lib.cspmm_coo_very_sparse_naive_fp16(
+                ptrMaxCount,
+                ptrMaxIdx,
+                ptrOffset,
+                ptrRowidx,
+                ptrColidx,
+                ptrValues,
+                ptrB,
+                ptrC,
+                ptrDequantStats,
+                cnnz_rows,
+                cnnz,
+                crowsA,
+                crowsB,
+                ccolsB,
+            )
+        elif B.dtype == torch.int8:
+            lib.cspmm_coo_very_sparse_naive_int8(
+                ptrMaxCount,
+                ptrMaxIdx,
+                ptrOffset,
+                ptrRowidx,
+                ptrColidx,
+                ptrValues,
+                ptrB,
+                ptrC,
+                ptrDequantStats,
+                cnnz_rows,
+                cnnz,
+                crowsA,
+                crowsB,
+                ccolsB,
+            )
+        # else: assertion error
 
     return out
 
@@ -2463,18 +2446,6 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
         return None
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def vectorwise_dequant(xq, max1, quant_type="vector"):
-    if quant_type == "vector":
-        x = (xq / C * max1).to(torch.float32)
-        return x
-    else:
-        return None
-
-
 @deprecated(
     "This function is deprecated and will be removed in a future release.",
     category=FutureWarning,
