Skip to content

Commit 972c51f

Browse files
Update create_dynamic_map to always return a float32 tensor
1 parent 86b6c37 commit 972c51f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bitsandbytes/functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,14 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
389389
if signed
390390
else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
391391
)
392-
boundaries = torch.linspace(0.1, 1, fraction_items)
392+
boundaries = torch.linspace(0.1, 1, fraction_items, dtype=torch.float32)
393393
means = (boundaries[:-1] + boundaries[1:]) / 2.0
394394
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
395395
if signed:
396396
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
397397

398398
if additional_items > 0:
399-
boundaries = torch.linspace(0.1, 1, additional_items + 1)
399+
boundaries = torch.linspace(0.1, 1, additional_items + 1, dtype=torch.float32)
400400
means = (boundaries[:-1] + boundaries[1:]) / 2.0
401401
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
402402
if signed:
@@ -412,7 +412,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
412412
data.append(0)
413413

414414
data.sort()
415-
return torch.tensor(data)
415+
return torch.tensor(data, dtype=torch.float32)
416416

417417

418418
def create_quantile_map(A, total_bits=8):

0 commit comments

Comments
 (0)