
Commit daad33d

keep cuda op
Signed-off-by: jiqing-feng <[email protected]>
1 parent 50ee994

File tree

1 file changed (+4, -2)

bitsandbytes/_ops.py

Lines changed: 4 additions & 2 deletions
@@ -225,7 +225,8 @@ def _(
 
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
+    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
     out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
     return out, absmax
 
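Both hunks apply the same fix: the fake (shape-inference) kernel now allocates absmax in float32 when CUDA is available, presumably to match the dtype the real CUDA op writes (hence the commit message "keep cuda op"), while other backends keep the input dtype. Below is a minimal standalone sketch of this first, 4-bit function after the patch. The hunk header truncates the signature at "def _(", so the name and parameter list here are inferred from the body, not confirmed by the diff.

import torch

# Hypothetical standalone version of the patched 4-bit fake kernel; in the
# repo it is an anonymous `def _(...)` under a registration decorator.
def quantize_4bit_fake(
    A: torch.Tensor, blocksize: int, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    n = A.numel()
    blocks = -(n // -blocksize)  # ceiling division: ceil(n / blocksize)
    # After this commit: float32 absmax on CUDA builds, input dtype elsewhere.
    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
    # Two 4-bit values are packed per storage element, hence itemsize * 2.
    out = torch.empty(
        ((n + 1) // (quant_storage.itemsize * 2), 1),
        device=A.device,
        dtype=quant_storage,
    )
    return out, absmax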

@@ -268,7 +269,8 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     torch._check_is_size(blocksize)
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
+    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
     out = torch.empty_like(A, dtype=torch.uint8)
     return out, absmax
 
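The second hunk makes the identical dtype change in the blockwise 8-bit fake kernel; only the output allocation differs (uint8, same shape as A). A sketch under the same assumptions (the standalone name is hypothetical; the code parameter is accepted but, as in the diff, unused for shape inference):

import torch

# Hypothetical standalone version of the second patched fake kernel.
def quantize_blockwise_fake(
    A: torch.Tensor, code: torch.Tensor, blocksize: int
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)  # validate blocksize as a size
    n = A.numel()
    blocks = -(n // -blocksize)  # ceiling division: ceil(n / blocksize)
    # Same fix as above: float32 absmax when CUDA is available.
    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
    out = torch.empty_like(A, dtype=torch.uint8)  # one byte per element
    return out, absmax

Note that the condition is torch.cuda.is_available() rather than a check on A.device, so a CUDA-enabled build reports float32 absmax even for CPU tensors.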
