Skip to content

Commit a93b91f

Browse files
small perf optimization for single-GPU systems
1 parent ed922b8 commit a93b91f

File tree

1 file changed

+64
-57
lines changed

1 file changed

+64
-57
lines changed

bitsandbytes/functional.py

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,16 @@ def get_instance(cls):
191191

192192
FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
193193

194+
if torch.cuda.device_count() > 1:
195+
196+
def _cuda_device_of(a: torch.Tensor):
197+
return torch.cuda.device_of(a)
198+
else:
199+
import contextlib
200+
201+
def _cuda_device_of(a: torch.Tensor):
202+
return contextlib.nullcontext()
203+
194204

195205
def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
196206
num_bytes = dtype2bytes[dtype] * prod(shape)
@@ -881,7 +891,7 @@ def quantize_blockwise(
881891

882892
is_on_gpu([A, out, absmax])
883893

884-
with torch.cuda.device_of(A):
894+
with _cuda_device_of(A):
885895
args = (
886896
get_ptr(code),
887897
get_ptr(A),
@@ -992,7 +1002,7 @@ def dequantize_blockwise(
9921002

9931003
is_on_gpu([A, absmax, out])
9941004

995-
with torch.cuda.device_of(A):
1005+
with _cuda_device_of(A):
9961006
args = (
9971007
get_ptr(quant_state.code),
9981008
get_ptr(A),
@@ -1183,7 +1193,7 @@ def quantize_4bit(
11831193

11841194
is_on_gpu([A, out, absmax])
11851195

1186-
with torch.cuda.device_of(A):
1196+
with _cuda_device_of(A):
11871197
args = (
11881198
get_ptr(None),
11891199
get_ptr(A),
@@ -1330,7 +1340,7 @@ def dequantize_4bit(
13301340
is_on_gpu([A, absmax, out])
13311341
stream = _get_tensor_stream(A)
13321342

1333-
with torch.cuda.device_of(A):
1343+
with _cuda_device_of(A):
13341344
args = (
13351345
get_ptr(None),
13361346
get_ptr(A),
@@ -1547,28 +1557,28 @@ def optimizer_update_32bit(
15471557
)
15481558

15491559
is_on_gpu([g, p, state1, state2, unorm_vec])
1550-
prev_device = pre_call(g.device)
1551-
optim_func(
1552-
get_ptr(g),
1553-
get_ptr(p),
1554-
get_ptr(state1),
1555-
get_ptr(state2),
1556-
get_ptr(unorm_vec),
1557-
ct.c_float(max_unorm),
1558-
ct.c_float(param_norm),
1559-
ct.c_float(beta1),
1560-
ct.c_float(beta2),
1561-
ct.c_float(beta3),
1562-
ct.c_float(alpha),
1563-
ct.c_float(eps),
1564-
ct.c_float(weight_decay),
1565-
ct.c_int32(step),
1566-
ct.c_float(lr),
1567-
ct.c_float(gnorm_scale),
1568-
ct.c_bool(skip_zeros),
1569-
ct.c_int32(g.numel()),
1570-
)
1571-
post_call(prev_device)
1560+
1561+
with _cuda_device_of(g):
1562+
optim_func(
1563+
get_ptr(g),
1564+
get_ptr(p),
1565+
get_ptr(state1),
1566+
get_ptr(state2),
1567+
get_ptr(unorm_vec),
1568+
ct.c_float(max_unorm),
1569+
ct.c_float(param_norm),
1570+
ct.c_float(beta1),
1571+
ct.c_float(beta2),
1572+
ct.c_float(beta3),
1573+
ct.c_float(alpha),
1574+
ct.c_float(eps),
1575+
ct.c_float(weight_decay),
1576+
ct.c_int32(step),
1577+
ct.c_float(lr),
1578+
ct.c_float(gnorm_scale),
1579+
ct.c_bool(skip_zeros),
1580+
ct.c_int32(g.numel()),
1581+
)
15721582

15731583

15741584
@deprecated(
@@ -1731,8 +1741,7 @@ def optimizer_update_8bit_blockwise(
17311741
skip_zeros=False,
17321742
) -> None:
17331743
optim_func = None
1734-
prev_device = pre_call(g.device)
1735-
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
1744+
17361745
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
17371746
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
17381747
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
@@ -1747,33 +1756,31 @@ def optimizer_update_8bit_blockwise(
17471756
raise ValueError(
17481757
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
17491758
)
1750-
post_call(prev_device)
17511759

17521760
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
17531761

1754-
prev_device = pre_call(g.device)
1755-
optim_func(
1756-
get_ptr(p),
1757-
get_ptr(g),
1758-
get_ptr(state1),
1759-
get_ptr(state2),
1760-
ct.c_float(beta1),
1761-
ct.c_float(beta2),
1762-
ct.c_float(beta3),
1763-
ct.c_float(alpha),
1764-
ct.c_float(eps),
1765-
ct.c_int32(step),
1766-
ct.c_float(lr),
1767-
get_ptr(qmap1),
1768-
get_ptr(qmap2),
1769-
get_ptr(absmax1),
1770-
get_ptr(absmax2),
1771-
ct.c_float(weight_decay),
1772-
ct.c_float(gnorm_scale),
1773-
ct.c_bool(skip_zeros),
1774-
ct.c_int32(g.numel()),
1775-
)
1776-
post_call(prev_device)
1762+
with _cuda_device_of(g):
1763+
optim_func(
1764+
get_ptr(p),
1765+
get_ptr(g),
1766+
get_ptr(state1),
1767+
get_ptr(state2),
1768+
ct.c_float(beta1),
1769+
ct.c_float(beta2),
1770+
ct.c_float(beta3),
1771+
ct.c_float(alpha),
1772+
ct.c_float(eps),
1773+
ct.c_int32(step),
1774+
ct.c_float(lr),
1775+
get_ptr(qmap1),
1776+
get_ptr(qmap2),
1777+
get_ptr(absmax1),
1778+
get_ptr(absmax2),
1779+
ct.c_float(weight_decay),
1780+
ct.c_float(gnorm_scale),
1781+
ct.c_bool(skip_zeros),
1782+
ct.c_int32(g.numel()),
1783+
)
17771784

17781785

17791786
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
@@ -1966,7 +1973,7 @@ def gemv_4bit(
19661973
ldc = ct.c_int32(ldc)
19671974
stream = _get_tensor_stream(A)
19681975

1969-
with torch.cuda.device_of(A):
1976+
with _cuda_device_of(A):
19701977
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
19711978
if A.dtype == torch.float16:
19721979
lib.cgemm_4bit_inference_naive_fp16(
@@ -2285,7 +2292,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
22852292

22862293
is_on_gpu([A, B, out])
22872294

2288-
with torch.cuda.device_of(A):
2295+
with _cuda_device_of(A):
22892296
ctx = CUBLAS_Context.get_instance().get_context(A.device)
22902297
ptrA = get_ptr(A)
22912298
ptrB = get_ptr(B)
@@ -2343,7 +2350,7 @@ def int8_mm_dequant(
23432350

23442351
is_on_gpu([A, row_stats, col_stats, out, bias])
23452352

2346-
with torch.cuda.device_of(A):
2353+
with _cuda_device_of(A):
23472354
lib.cdequant_mm_int32_fp16(
23482355
ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
23492356
)
@@ -2407,7 +2414,7 @@ def get_row_absmax(A: torch.Tensor, threshold=0.0):
24072414

24082415
is_on_gpu([A])
24092416

2410-
with torch.cuda.device_of(A):
2417+
with _cuda_device_of(A):
24112418
lib.cget_row_stats(
24122419
get_ptr(A),
24132420
get_ptr(row_stats),
@@ -2550,7 +2557,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
25502557
if outliers.any():
25512558
outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
25522559

2553-
with torch.cuda.device_of(A):
2560+
with _cuda_device_of(A):
25542561
lib.cint8_vector_quant(
25552562
get_ptr(A),
25562563
get_ptr(out_row),

0 commit comments

Comments
 (0)