Skip to content

Commit a72c463

Browse files
cleanup
1 parent 35dbb2e commit a72c463

File tree

1 file changed

+92
-172
lines changed

1 file changed

+92
-172
lines changed

bitsandbytes/functional.py

Lines changed: 92 additions & 172 deletions
Original file line number | Diff line number | Diff line change
@@ -828,13 +828,13 @@ def __eq__(self, other):
828828

829829

830830
def quantize_blockwise(
831-
A: Tensor,
831+
A: torch.Tensor,
832832
code: Optional[torch.Tensor] = None,
833833
absmax: Optional[torch.Tensor] = None,
834834
out: Optional[torch.Tensor] = None,
835835
blocksize=4096,
836836
nested=False,
837-
) -> Tuple[Tensor, QuantState]:
837+
) -> Tuple[torch.Tensor, QuantState]:
838838
"""
839839
Quantize tensor A in blocks of size 4096 values.
840840
@@ -878,21 +878,11 @@ def quantize_blockwise(
878878
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
879879

880880
code = code.to(A.device)
881-
is_on_gpu([code, A, out, absmax])
882-
883-
fn_map = {
884-
torch.float32: "cquantize_blockwise_fp32",
885-
torch.bfloat16: "cquantize_blockwise_bf16",
886-
torch.float16: "cquantize_blockwise_fp16",
887-
}
888881

889-
if A.dtype not in fn_map.keys():
890-
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
891-
892-
fn = fn_map[A.dtype]
882+
is_on_gpu([A, out, absmax])
893883

894884
with torch.cuda.device_of(A):
895-
lib[fn](
885+
args = (
896886
get_ptr(code),
897887
get_ptr(A),
898888
get_ptr(absmax),
@@ -901,6 +891,15 @@ def quantize_blockwise(
901891
ct.c_int(A.numel()),
902892
)
903893

894+
if A.dtype == torch.float16:
895+
lib.cquantize_blockwise_fp16(*args)
896+
elif A.dtype == torch.bfloat16:
897+
lib.cquantize_blockwise_bf16(*args)
898+
elif A.dtype == torch.float32:
899+
lib.cquantize_blockwise_fp32(*args)
900+
else:
901+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
902+
904903
else:
905904
# cpu
906905
code = code.cpu()
@@ -932,14 +931,14 @@ def quantize_blockwise(
932931

933932

934933
def dequantize_blockwise(
935-
A: Tensor,
934+
A: torch.Tensor,
936935
quant_state: Optional[QuantState] = None,
937936
absmax: Optional[torch.Tensor] = None,
938937
code: Optional[torch.Tensor] = None,
939938
out: Optional[torch.Tensor] = None,
940939
blocksize: int = 4096,
941940
nested=False,
942-
) -> Tensor:
941+
) -> torch.Tensor:
943942
"""
944943
Dequantizes blockwise quantized values.
945944
@@ -986,25 +985,15 @@ def dequantize_blockwise(
986985

987986
if A.device.type != "cpu":
988987
code = quant_state.code.to(A.device)
989-
if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
988+
if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]:
990989
raise ValueError(
991-
f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
990+
f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]",
992991
)
993-
is_on_gpu([A, absmax, out])
994992

995-
fn_map = {
996-
torch.float32: "cdequantize_blockwise_fp32",
997-
torch.bfloat16: "cdequantize_blockwise_bf16",
998-
torch.float16: "cdequantize_blockwise_fp16",
999-
}
1000-
1001-
if out.dtype not in fn_map.keys():
1002-
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
1003-
1004-
fn = fn_map[out.dtype]
993+
is_on_gpu([A, absmax, out])
1005994

1006995
with torch.cuda.device_of(A):
1007-
lib[fn](
996+
args = (
1008997
get_ptr(quant_state.code),
1009998
get_ptr(A),
1010999
get_ptr(absmax),
@@ -1013,6 +1002,15 @@ def dequantize_blockwise(
10131002
ct.c_int(A.numel()),
10141003
_get_tensor_stream(A),
10151004
)
1005+
1006+
if out.dtype == torch.float16:
1007+
lib.cdequantize_blockwise_fp16(*args)
1008+
elif out.dtype == torch.bfloat16:
1009+
lib.cdequantize_blockwise_bf16(*args)
1010+
elif out.dtype == torch.float32:
1011+
lib.cdequantize_blockwise_fp32(*args)
1012+
else:
1013+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
10161014
else:
10171015
code = quant_state.code.cpu()
10181016
lib.cdequantize_blockwise_cpu_fp32(
@@ -1110,7 +1108,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
11101108

11111109

11121110
def quantize_fp4(
1113-
A: Tensor,
1111+
A: torch.Tensor,
11141112
absmax: Optional[torch.Tensor] = None,
11151113
out: Optional[torch.Tensor] = None,
11161114
blocksize=64,
@@ -1121,7 +1119,7 @@ def quantize_fp4(
11211119

11221120

11231121
def quantize_nf4(
1124-
A: Tensor,
1122+
A: torch.Tensor,
11251123
absmax: Optional[torch.Tensor] = None,
11261124
out: Optional[torch.Tensor] = None,
11271125
blocksize=64,
@@ -1132,14 +1130,14 @@ def quantize_nf4(
11321130

11331131

11341132
def quantize_4bit(
1135-
A: Tensor,
1133+
A: torch.Tensor,
11361134
absmax: Optional[torch.Tensor] = None,
11371135
out: Optional[torch.Tensor] = None,
11381136
blocksize=64,
11391137
compress_statistics=False,
11401138
quant_type="fp4",
11411139
quant_storage=torch.uint8,
1142-
) -> Tuple[Tensor, QuantState]:
1140+
) -> Tuple[torch.Tensor, QuantState]:
11431141
"""
11441142
Quantize tensor A in blocks of 4-bit values.
11451143
@@ -1184,71 +1182,34 @@ def quantize_4bit(
11841182
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
11851183

11861184
is_on_gpu([A, out, absmax])
1187-
if A.dtype == torch.float32:
1188-
if quant_type == "fp4":
1189-
with torch.cuda.device_of(A):
1190-
lib.cquantize_blockwise_fp32_fp4(
1191-
get_ptr(None),
1192-
get_ptr(A),
1193-
get_ptr(absmax),
1194-
get_ptr(out),
1195-
ct.c_int32(blocksize),
1196-
ct.c_int(n),
1197-
)
1198-
else:
1199-
with torch.cuda.device_of(A):
1200-
lib.cquantize_blockwise_fp32_nf4(
1201-
get_ptr(None),
1202-
get_ptr(A),
1203-
get_ptr(absmax),
1204-
get_ptr(out),
1205-
ct.c_int32(blocksize),
1206-
ct.c_int(n),
1207-
)
1208-
elif A.dtype == torch.float16:
1209-
if quant_type == "fp4":
1210-
with torch.cuda.device_of(A):
1211-
lib.cquantize_blockwise_fp16_fp4(
1212-
get_ptr(None),
1213-
get_ptr(A),
1214-
get_ptr(absmax),
1215-
get_ptr(out),
1216-
ct.c_int32(blocksize),
1217-
ct.c_int(n),
1218-
)
1219-
else:
1220-
with torch.cuda.device_of(A):
1221-
lib.cquantize_blockwise_fp16_nf4(
1222-
get_ptr(None),
1223-
get_ptr(A),
1224-
get_ptr(absmax),
1225-
get_ptr(out),
1226-
ct.c_int32(blocksize),
1227-
ct.c_int(n),
1228-
)
1229-
elif A.dtype == torch.bfloat16:
1230-
if quant_type == "fp4":
1231-
with torch.cuda.device_of(A):
1232-
lib.cquantize_blockwise_bf16_fp4(
1233-
get_ptr(None),
1234-
get_ptr(A),
1235-
get_ptr(absmax),
1236-
get_ptr(out),
1237-
ct.c_int32(blocksize),
1238-
ct.c_int(n),
1239-
)
1185+
1186+
with torch.cuda.device_of(A):
1187+
args = (
1188+
get_ptr(None),
1189+
get_ptr(A),
1190+
get_ptr(absmax),
1191+
get_ptr(out),
1192+
ct.c_int32(blocksize),
1193+
ct.c_int(n),
1194+
)
1195+
1196+
if A.dtype == torch.bfloat16:
1197+
if quant_type == "fp4":
1198+
lib.cquantize_blockwise_bf16_fp4(*args)
1199+
else:
1200+
lib.cquantize_blockwise_bf16_nf4(*args)
1201+
elif A.dtype == torch.float16:
1202+
if quant_type == "fp4":
1203+
lib.cquantize_blockwise_fp16_fp4(*args)
1204+
else:
1205+
lib.cquantize_blockwise_fp16_nf4(*args)
1206+
elif A.dtype == torch.float32:
1207+
if quant_type == "fp4":
1208+
lib.cquantize_blockwise_fp32_fp4(*args)
1209+
else:
1210+
lib.cquantize_blockwise_fp32_nf4(*args)
12401211
else:
1241-
with torch.cuda.device_of(A):
1242-
lib.cquantize_blockwise_bf16_nf4(
1243-
get_ptr(None),
1244-
get_ptr(A),
1245-
get_ptr(absmax),
1246-
get_ptr(out),
1247-
ct.c_int32(blocksize),
1248-
ct.c_int(n),
1249-
)
1250-
else:
1251-
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
1212+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
12521213

12531214
code = get_4bit_type(quant_type, device=A.device)
12541215

@@ -1281,33 +1242,33 @@ def quantize_4bit(
12811242

12821243

12831244
def dequantize_fp4(
1284-
A: Tensor,
1245+
A: torch.Tensor,
12851246
quant_state: Optional[QuantState] = None,
12861247
absmax: Optional[torch.Tensor] = None,
12871248
out: Optional[torch.Tensor] = None,
12881249
blocksize: int = 64,
1289-
) -> Tensor:
1250+
) -> torch.Tensor:
12901251
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
12911252

12921253

12931254
def dequantize_nf4(
1294-
A: Tensor,
1255+
A: torch.Tensor,
12951256
quant_state: Optional[QuantState] = None,
12961257
absmax: Optional[torch.Tensor] = None,
12971258
out: Optional[torch.Tensor] = None,
12981259
blocksize: int = 64,
1299-
) -> Tensor:
1260+
) -> torch.Tensor:
13001261
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
13011262

13021263

13031264
def dequantize_4bit(
1304-
A: Tensor,
1265+
A: torch.Tensor,
13051266
quant_state: Optional[QuantState] = None,
13061267
absmax: Optional[torch.Tensor] = None,
13071268
out: Optional[torch.Tensor] = None,
13081269
blocksize: int = 64,
13091270
quant_type="fp4",
1310-
) -> Tensor:
1271+
) -> torch.Tensor:
13111272
"""
13121273
Dequantizes FP4 blockwise quantized values.
13131274
@@ -1368,76 +1329,35 @@ def dequantize_4bit(
13681329

13691330
is_on_gpu([A, absmax, out])
13701331
stream = _get_tensor_stream(A)
1371-
if out.dtype == torch.float32:
1372-
if quant_state.quant_type == "fp4":
1373-
with torch.cuda.device_of(A):
1374-
lib.cdequantize_blockwise_fp32_fp4(
1375-
get_ptr(None),
1376-
get_ptr(A),
1377-
get_ptr(absmax),
1378-
get_ptr(out),
1379-
ct.c_int(quant_state.blocksize),
1380-
ct.c_int(n),
1381-
stream,
1382-
)
1383-
else:
1384-
with torch.cuda.device_of(A):
1385-
lib.cdequantize_blockwise_fp32_nf4(
1386-
get_ptr(None),
1387-
get_ptr(A),
1388-
get_ptr(absmax),
1389-
get_ptr(out),
1390-
ct.c_int(quant_state.blocksize),
1391-
ct.c_int(n),
1392-
stream,
1393-
)
1394-
elif out.dtype == torch.float16:
1395-
if quant_state.quant_type == "fp4":
1396-
with torch.cuda.device_of(A):
1397-
lib.cdequantize_blockwise_fp16_fp4(
1398-
get_ptr(None),
1399-
get_ptr(A),
1400-
get_ptr(absmax),
1401-
get_ptr(out),
1402-
ct.c_int(quant_state.blocksize),
1403-
ct.c_int(n),
1404-
stream,
1405-
)
1406-
else:
1407-
with torch.cuda.device_of(A):
1408-
lib.cdequantize_blockwise_fp16_nf4(
1409-
get_ptr(None),
1410-
get_ptr(A),
1411-
get_ptr(absmax),
1412-
get_ptr(out),
1413-
ct.c_int(quant_state.blocksize),
1414-
ct.c_int(n),
1415-
stream,
1416-
)
1417-
elif out.dtype == torch.bfloat16:
1418-
with torch.cuda.device_of(A):
1332+
1333+
with torch.cuda.device_of(A):
1334+
args = (
1335+
get_ptr(None),
1336+
get_ptr(A),
1337+
get_ptr(absmax),
1338+
get_ptr(out),
1339+
ct.c_int(quant_state.blocksize),
1340+
ct.c_int(n),
1341+
stream,
1342+
)
1343+
1344+
if out.dtype == torch.bfloat16:
14191345
if quant_state.quant_type == "fp4":
1420-
lib.cdequantize_blockwise_bf16_fp4(
1421-
get_ptr(None),
1422-
get_ptr(A),
1423-
get_ptr(absmax),
1424-
get_ptr(out),
1425-
ct.c_int(quant_state.blocksize),
1426-
ct.c_int(n),
1427-
stream,
1428-
)
1346+
lib.cdequantize_blockwise_bf16_fp4(*args)
14291347
else:
1430-
lib.cdequantize_blockwise_bf16_nf4(
1431-
get_ptr(None),
1432-
get_ptr(A),
1433-
get_ptr(absmax),
1434-
get_ptr(out),
1435-
ct.c_int(quant_state.blocksize),
1436-
ct.c_int(n),
1437-
stream,
1438-
)
1439-
else:
1440-
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
1348+
lib.cdequantize_blockwise_bf16_nf4(*args)
1349+
elif out.dtype == torch.float16:
1350+
if quant_state.quant_type == "fp4":
1351+
lib.cdequantize_blockwise_fp16_fp4(*args)
1352+
else:
1353+
lib.cdequantize_blockwise_fp16_nf4(*args)
1354+
elif out.dtype == torch.float32:
1355+
if quant_state.quant_type == "fp4":
1356+
lib.cdequantize_blockwise_fp32_fp4(*args)
1357+
else:
1358+
lib.cdequantize_blockwise_fp32_nf4(*args)
1359+
else:
1360+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
14411361

14421362
if A.shape[0] == 1: # is transposed, transpose back
14431363
return out.t()

0 commit comments

Comments
 (0)