Skip to content

Commit 62a333a

Browse files
committed
Added pre/post calls to quantize_blockwise.
1 parent e0e697b commit 62a333a

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

bitsandbytes/functional.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
458458
"""
459459

460460

461+
prev_device = pre_call(A.device)
461462
if code is None:
462463
if "dynamic" not in name2qmap:
463464
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
479480
is_on_gpu([code, A, absmax, out, rand])
480481
cblocksize = ct.c_int32(blocksize)
481482
if rand is not None:
483+
is_on_gpu([code, A, out, absmax, rand])
482484
assert blocksize==4096
483485
assert rand.numel() >= 1024
484486
rand_offset = random.randint(0, 1023)
@@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
489491
else:
490492
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
491493
else:
494+
is_on_gpu([code, A, out, absmax])
492495
if A.dtype == torch.float32:
493496
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
494497
elif A.dtype == torch.float16:
@@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
499502
# cpu
500503
assert rand is None
501504
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
505+
post_call(A.device)
502506

503507
return out, (absmax, code)
504508

@@ -537,6 +541,7 @@ def dequantize_blockwise(
537541
Dequantized tensor (default: float32)
538542
"""
539543
assert quant_state is not None or absmax is not None
544+
device = pre_call(A.device)
540545
if code is None and quant_state is None:
541546
if "dynamic" not in name2qmap:
542547
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -561,6 +566,7 @@ def dequantize_blockwise(
561566
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
562567
else:
563568
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
569+
post_call(A.device)
564570

565571
return out
566572

0 commit comments

Comments
 (0)