@@ -828,13 +828,13 @@ def __eq__(self, other):
828828
829829
830830def quantize_blockwise (
831- A : Tensor ,
831+ A : torch . Tensor ,
832832 code : Optional [torch .Tensor ] = None ,
833833 absmax : Optional [torch .Tensor ] = None ,
834834 out : Optional [torch .Tensor ] = None ,
835835 blocksize = 4096 ,
836836 nested = False ,
837- ) -> Tuple [Tensor , QuantState ]:
837+ ) -> Tuple [torch . Tensor , QuantState ]:
838838 """
839839 Quantize tensor A in blocks of size 4096 values.
840840
@@ -878,21 +878,11 @@ def quantize_blockwise(
878878 assert blocksize in [4096 , 2048 , 1024 , 512 , 256 , 128 , 64 ]
879879
880880 code = code .to (A .device )
881- is_on_gpu ([code , A , out , absmax ])
882-
883- fn_map = {
884- torch .float32 : "cquantize_blockwise_fp32" ,
885- torch .bfloat16 : "cquantize_blockwise_bf16" ,
886- torch .float16 : "cquantize_blockwise_fp16" ,
887- }
888881
889- if A .dtype not in fn_map .keys ():
890- raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
891-
892- fn = fn_map [A .dtype ]
882+ is_on_gpu ([A , out , absmax ])
893883
894884 with torch .cuda .device_of (A ):
895- lib [ fn ] (
885+ args = (
896886 get_ptr (code ),
897887 get_ptr (A ),
898888 get_ptr (absmax ),
@@ -901,6 +891,15 @@ def quantize_blockwise(
901891 ct .c_int (A .numel ()),
902892 )
903893
894+ if A .dtype == torch .float16 :
895+ lib .cquantize_blockwise_fp16 (* args )
896+ elif A .dtype == torch .bfloat16 :
897+ lib .cquantize_blockwise_bf16 (* args )
898+ elif A .dtype == torch .float32 :
899+ lib .cquantize_blockwise_fp32 (* args )
900+ else :
901+ raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
902+
904903 else :
905904 # cpu
906905 code = code .cpu ()
@@ -932,14 +931,14 @@ def quantize_blockwise(
932931
933932
934933def dequantize_blockwise (
935- A : Tensor ,
934+ A : torch . Tensor ,
936935 quant_state : Optional [QuantState ] = None ,
937936 absmax : Optional [torch .Tensor ] = None ,
938937 code : Optional [torch .Tensor ] = None ,
939938 out : Optional [torch .Tensor ] = None ,
940939 blocksize : int = 4096 ,
941940 nested = False ,
942- ) -> Tensor :
941+ ) -> torch . Tensor :
943942 """
944943 Dequantizes blockwise quantized values.
945944
@@ -986,25 +985,15 @@ def dequantize_blockwise(
986985
987986 if A .device .type != "cpu" :
988987 code = quant_state .code .to (A .device )
989- if quant_state .blocksize not in [2048 , 4096 , 1024 , 512 , 256 , 128 , 64 ]:
988+ if quant_state .blocksize not in [4096 , 2048 , 1024 , 512 , 256 , 128 , 64 ]:
990989 raise ValueError (
991- f"The blockwise of { quant_state .blocksize } is not supported. Supported values: [2048, 4096 , 1024, 512, 256, 128, 64]" ,
990+ f"The blocksize of { quant_state .blocksize } is not supported. Supported values: [4096, 2048 , 1024, 512, 256, 128, 64]" ,
992991 )
993- is_on_gpu ([A , absmax , out ])
994992
995- fn_map = {
996- torch .float32 : "cdequantize_blockwise_fp32" ,
997- torch .bfloat16 : "cdequantize_blockwise_bf16" ,
998- torch .float16 : "cdequantize_blockwise_fp16" ,
999- }
1000-
1001- if out .dtype not in fn_map .keys ():
1002- raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { out .dtype } " )
1003-
1004- fn = fn_map [out .dtype ]
993+ is_on_gpu ([A , absmax , out ])
1005994
1006995 with torch .cuda .device_of (A ):
1007- lib [ fn ] (
996+ args = (
1008997 get_ptr (quant_state .code ),
1009998 get_ptr (A ),
1010999 get_ptr (absmax ),
@@ -1013,6 +1002,15 @@ def dequantize_blockwise(
10131002 ct .c_int (A .numel ()),
10141003 _get_tensor_stream (A ),
10151004 )
1005+
1006+ if out .dtype == torch .float16 :
1007+ lib .cdequantize_blockwise_fp16 (* args )
1008+ elif out .dtype == torch .bfloat16 :
1009+ lib .cdequantize_blockwise_bf16 (* args )
1010+ elif out .dtype == torch .float32 :
1011+ lib .cdequantize_blockwise_fp32 (* args )
1012+ else :
1013+ raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { out .dtype } " )
10161014 else :
10171015 code = quant_state .code .cpu ()
10181016 lib .cdequantize_blockwise_cpu_fp32 (
@@ -1110,7 +1108,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
11101108
11111109
11121110def quantize_fp4 (
1113- A : Tensor ,
1111+ A : torch . Tensor ,
11141112 absmax : Optional [torch .Tensor ] = None ,
11151113 out : Optional [torch .Tensor ] = None ,
11161114 blocksize = 64 ,
@@ -1121,7 +1119,7 @@ def quantize_fp4(
11211119
11221120
11231121def quantize_nf4 (
1124- A : Tensor ,
1122+ A : torch . Tensor ,
11251123 absmax : Optional [torch .Tensor ] = None ,
11261124 out : Optional [torch .Tensor ] = None ,
11271125 blocksize = 64 ,
@@ -1132,14 +1130,14 @@ def quantize_nf4(
11321130
11331131
11341132def quantize_4bit (
1135- A : Tensor ,
1133+ A : torch . Tensor ,
11361134 absmax : Optional [torch .Tensor ] = None ,
11371135 out : Optional [torch .Tensor ] = None ,
11381136 blocksize = 64 ,
11391137 compress_statistics = False ,
11401138 quant_type = "fp4" ,
11411139 quant_storage = torch .uint8 ,
1142- ) -> Tuple [Tensor , QuantState ]:
1140+ ) -> Tuple [torch . Tensor , QuantState ]:
11431141 """
11441142 Quantize tensor A in blocks of 4-bit values.
11451143
@@ -1184,71 +1182,34 @@ def quantize_4bit(
11841182 assert blocksize in [4096 , 2048 , 1024 , 512 , 256 , 128 , 64 ]
11851183
11861184 is_on_gpu ([A , out , absmax ])
1187- if A .dtype == torch .float32 :
1188- if quant_type == "fp4" :
1189- with torch .cuda .device_of (A ):
1190- lib .cquantize_blockwise_fp32_fp4 (
1191- get_ptr (None ),
1192- get_ptr (A ),
1193- get_ptr (absmax ),
1194- get_ptr (out ),
1195- ct .c_int32 (blocksize ),
1196- ct .c_int (n ),
1197- )
1198- else :
1199- with torch .cuda .device_of (A ):
1200- lib .cquantize_blockwise_fp32_nf4 (
1201- get_ptr (None ),
1202- get_ptr (A ),
1203- get_ptr (absmax ),
1204- get_ptr (out ),
1205- ct .c_int32 (blocksize ),
1206- ct .c_int (n ),
1207- )
1208- elif A .dtype == torch .float16 :
1209- if quant_type == "fp4" :
1210- with torch .cuda .device_of (A ):
1211- lib .cquantize_blockwise_fp16_fp4 (
1212- get_ptr (None ),
1213- get_ptr (A ),
1214- get_ptr (absmax ),
1215- get_ptr (out ),
1216- ct .c_int32 (blocksize ),
1217- ct .c_int (n ),
1218- )
1219- else :
1220- with torch .cuda .device_of (A ):
1221- lib .cquantize_blockwise_fp16_nf4 (
1222- get_ptr (None ),
1223- get_ptr (A ),
1224- get_ptr (absmax ),
1225- get_ptr (out ),
1226- ct .c_int32 (blocksize ),
1227- ct .c_int (n ),
1228- )
1229- elif A .dtype == torch .bfloat16 :
1230- if quant_type == "fp4" :
1231- with torch .cuda .device_of (A ):
1232- lib .cquantize_blockwise_bf16_fp4 (
1233- get_ptr (None ),
1234- get_ptr (A ),
1235- get_ptr (absmax ),
1236- get_ptr (out ),
1237- ct .c_int32 (blocksize ),
1238- ct .c_int (n ),
1239- )
1185+
1186+ with torch .cuda .device_of (A ):
1187+ args = (
1188+ get_ptr (None ),
1189+ get_ptr (A ),
1190+ get_ptr (absmax ),
1191+ get_ptr (out ),
1192+ ct .c_int32 (blocksize ),
1193+ ct .c_int (n ),
1194+ )
1195+
1196+ if A .dtype == torch .bfloat16 :
1197+ if quant_type == "fp4" :
1198+ lib .cquantize_blockwise_bf16_fp4 (* args )
1199+ else :
1200+ lib .cquantize_blockwise_bf16_nf4 (* args )
1201+ elif A .dtype == torch .float16 :
1202+ if quant_type == "fp4" :
1203+ lib .cquantize_blockwise_fp16_fp4 (* args )
1204+ else :
1205+ lib .cquantize_blockwise_fp16_nf4 (* args )
1206+ elif A .dtype == torch .float32 :
1207+ if quant_type == "fp4" :
1208+ lib .cquantize_blockwise_fp32_fp4 (* args )
1209+ else :
1210+ lib .cquantize_blockwise_fp32_nf4 (* args )
12401211 else :
1241- with torch .cuda .device_of (A ):
1242- lib .cquantize_blockwise_bf16_nf4 (
1243- get_ptr (None ),
1244- get_ptr (A ),
1245- get_ptr (absmax ),
1246- get_ptr (out ),
1247- ct .c_int32 (blocksize ),
1248- ct .c_int (n ),
1249- )
1250- else :
1251- raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
1212+ raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
12521213
12531214 code = get_4bit_type (quant_type , device = A .device )
12541215
@@ -1281,33 +1242,33 @@ def quantize_4bit(
12811242
12821243
12831244def dequantize_fp4 (
1284- A : Tensor ,
1245+ A : torch . Tensor ,
12851246 quant_state : Optional [QuantState ] = None ,
12861247 absmax : Optional [torch .Tensor ] = None ,
12871248 out : Optional [torch .Tensor ] = None ,
12881249 blocksize : int = 64 ,
1289- ) -> Tensor :
1250+ ) -> torch . Tensor :
12901251 return dequantize_4bit (A , quant_state , absmax , out , blocksize , "fp4" )
12911252
12921253
12931254def dequantize_nf4 (
1294- A : Tensor ,
1255+ A : torch . Tensor ,
12951256 quant_state : Optional [QuantState ] = None ,
12961257 absmax : Optional [torch .Tensor ] = None ,
12971258 out : Optional [torch .Tensor ] = None ,
12981259 blocksize : int = 64 ,
1299- ) -> Tensor :
1260+ ) -> torch . Tensor :
13001261 return dequantize_4bit (A , quant_state , absmax , out , blocksize , "nf4" )
13011262
13021263
13031264def dequantize_4bit (
1304- A : Tensor ,
1265+ A : torch . Tensor ,
13051266 quant_state : Optional [QuantState ] = None ,
13061267 absmax : Optional [torch .Tensor ] = None ,
13071268 out : Optional [torch .Tensor ] = None ,
13081269 blocksize : int = 64 ,
13091270 quant_type = "fp4" ,
1310- ) -> Tensor :
1271+ ) -> torch . Tensor :
13111272 """
13121273 Dequantizes FP4 blockwise quantized values.
13131274
@@ -1368,76 +1329,35 @@ def dequantize_4bit(
13681329
13691330 is_on_gpu ([A , absmax , out ])
13701331 stream = _get_tensor_stream (A )
1371- if out .dtype == torch .float32 :
1372- if quant_state .quant_type == "fp4" :
1373- with torch .cuda .device_of (A ):
1374- lib .cdequantize_blockwise_fp32_fp4 (
1375- get_ptr (None ),
1376- get_ptr (A ),
1377- get_ptr (absmax ),
1378- get_ptr (out ),
1379- ct .c_int (quant_state .blocksize ),
1380- ct .c_int (n ),
1381- stream ,
1382- )
1383- else :
1384- with torch .cuda .device_of (A ):
1385- lib .cdequantize_blockwise_fp32_nf4 (
1386- get_ptr (None ),
1387- get_ptr (A ),
1388- get_ptr (absmax ),
1389- get_ptr (out ),
1390- ct .c_int (quant_state .blocksize ),
1391- ct .c_int (n ),
1392- stream ,
1393- )
1394- elif out .dtype == torch .float16 :
1395- if quant_state .quant_type == "fp4" :
1396- with torch .cuda .device_of (A ):
1397- lib .cdequantize_blockwise_fp16_fp4 (
1398- get_ptr (None ),
1399- get_ptr (A ),
1400- get_ptr (absmax ),
1401- get_ptr (out ),
1402- ct .c_int (quant_state .blocksize ),
1403- ct .c_int (n ),
1404- stream ,
1405- )
1406- else :
1407- with torch .cuda .device_of (A ):
1408- lib .cdequantize_blockwise_fp16_nf4 (
1409- get_ptr (None ),
1410- get_ptr (A ),
1411- get_ptr (absmax ),
1412- get_ptr (out ),
1413- ct .c_int (quant_state .blocksize ),
1414- ct .c_int (n ),
1415- stream ,
1416- )
1417- elif out .dtype == torch .bfloat16 :
1418- with torch .cuda .device_of (A ):
1332+
1333+ with torch .cuda .device_of (A ):
1334+ args = (
1335+ get_ptr (None ),
1336+ get_ptr (A ),
1337+ get_ptr (absmax ),
1338+ get_ptr (out ),
1339+ ct .c_int (quant_state .blocksize ),
1340+ ct .c_int (n ),
1341+ stream ,
1342+ )
1343+
1344+ if out .dtype == torch .bfloat16 :
14191345 if quant_state .quant_type == "fp4" :
1420- lib .cdequantize_blockwise_bf16_fp4 (
1421- get_ptr (None ),
1422- get_ptr (A ),
1423- get_ptr (absmax ),
1424- get_ptr (out ),
1425- ct .c_int (quant_state .blocksize ),
1426- ct .c_int (n ),
1427- stream ,
1428- )
1346+ lib .cdequantize_blockwise_bf16_fp4 (* args )
14291347 else :
1430- lib .cdequantize_blockwise_bf16_nf4 (
1431- get_ptr (None ),
1432- get_ptr (A ),
1433- get_ptr (absmax ),
1434- get_ptr (out ),
1435- ct .c_int (quant_state .blocksize ),
1436- ct .c_int (n ),
1437- stream ,
1438- )
1439- else :
1440- raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
1348+ lib .cdequantize_blockwise_bf16_nf4 (* args )
1349+ elif out .dtype == torch .float16 :
1350+ if quant_state .quant_type == "fp4" :
1351+ lib .cdequantize_blockwise_fp16_fp4 (* args )
1352+ else :
1353+ lib .cdequantize_blockwise_fp16_nf4 (* args )
1354+ elif out .dtype == torch .float32 :
1355+ if quant_state .quant_type == "fp4" :
1356+ lib .cdequantize_blockwise_fp32_fp4 (* args )
1357+ else :
1358+ lib .cdequantize_blockwise_fp32_nf4 (* args )
1359+ else :
1360+ raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { out .dtype } " )
14411361
14421362 if A .shape [0 ] == 1 : # is transposed, transpose back
14431363 return out .t ()
0 commit comments