Commit 2bb5c00

Added pre/post call to all lib calls. Fixes #120
1 parent 29ab3a6 commit 2bb5c00


1 file changed: bitsandbytes/functional.py (13 additions, 0 deletions)

@@ -770,6 +770,8 @@ def optimizer_update_32bit(
             f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
         )
 
+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, unorm_vec])
     if g.dtype == torch.float32 and state1.dtype == torch.float32:
         str2optimizer32bit[optimizer_name][0](
             get_ptr(g),
@@ -812,6 +814,7 @@ def optimizer_update_32bit(
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+    post_call(prev_device)
 
 
 def optimizer_update_8bit(
@@ -890,6 +893,8 @@ def optimizer_update_8bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())
 
+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit[optimizer_name][0](
             get_ptr(p),
@@ -942,6 +947,7 @@ def optimizer_update_8bit(
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+    post_call(prev_device)
 
 
 def optimizer_update_8bit_blockwise(
@@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise(
     skip_zeros=False,
 ) -> None:
 
+    prev_device = pre_call(g.device)
+    is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][0](
             get_ptr(p),
@@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise(
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+    post_call(prev_device)
 
 
 def percentile_clipping(
@@ -1023,6 +1032,7 @@ def percentile_clipping(
         The current optimiation steps (number of past gradient norms).
 
     """
+    prev_device = pre_call(grad.device)
     is_on_gpu([grad, gnorm_vec])
     if grad.dtype == torch.float32:
         lib.cpercentile_clipping_g32(
@@ -1040,6 +1050,7 @@ def percentile_clipping(
         )
     else:
         raise ValueError(f"Gradient type {grad.dtype} not supported!")
+    post_call(prev_device)
 
     current_gnorm = torch.sqrt(gnorm_vec[step % 100])
     vals, idx = torch.sort(gnorm_vec)
@@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
             (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
         )
     nnz = cooA.nnz
+    prev_device = pre_call(B.device)
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
     assert cooA.values.numel() == nnz
@@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
             ccolsB,
         )
     # else: assertion error
+    post_call(prev_device)
 
     return out
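
For context, a minimal sketch of the device-guard pattern this commit wraps around every library call. The `pre_call`/`post_call` bodies below mirror the helpers in bitsandbytes/functional.py as I understand them; `wrapped_lib_call` is a hypothetical stand-in for any of the patched functions, not code from this commit:

    import torch

    def pre_call(device):
        # Save the CUDA device that is currently active, then switch to
        # the device the input tensors live on before entering the C library.
        prev_device = torch.cuda.current_device()
        torch.cuda.set_device(device)
        return prev_device

    def post_call(prev_device):
        # Switch back to whichever device was active before the call.
        torch.cuda.set_device(prev_device)

    def is_on_gpu(tensors):
        # All (non-None) tensors handed to the C library must already be on GPU.
        return all(t is None or t.device.type == "cuda" for t in tensors)

    def wrapped_lib_call(g, p):
        # Hypothetical stand-in for optimizer_update_32bit and friends:
        # every raw lib.c*(...) call is now bracketed the same way.
        prev_device = pre_call(g.device)   # make g's device active
        is_on_gpu([g, p])
        # ... lib.c<kernel>(get_ptr(g), get_ptr(p), ...) would run here ...
        post_call(prev_device)             # restore the previous device

The point of the guard is that CUDA kernels launch on whichever device is currently active: switching to the tensors' own device before the call, and back afterwards, keeps multi-GPU setups from launching kernels against the wrong device, which is presumably the failure mode behind #120.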
