@@ -429,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
429429 return (blocks .view (torch .int16 ).to (torch .int32 ) << 16 ).view (torch .float32 )
430430
431431
# This section is adapted from calcuis (gguf.org)
# More info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py
434+
435+
def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ4_NL blocks.

    Each block stores a float16 scale (2 bytes) followed by packed 4-bit
    indices into a fixed 16-entry non-linear codebook; the dequantized
    value is scale * codebook[index].
    """
    # Fixed non-linear 16-level codebook used by IQ4_NL.
    lut = torch.tensor(
        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
        dtype=torch.float32,
        device=blocks.device,
    )
    n_blocks = blocks.shape[0]

    scale, packed = split_block_dims(blocks, 2)
    scale = scale.view(torch.float16).to(dtype)

    # Each byte packs two 4-bit indices: low nibble first, then high nibble.
    nib_shift = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    idx = packed.reshape((n_blocks, -1, 1, block_size // 2)) >> nib_shift
    idx = (idx & 15).reshape((n_blocks, -1)).to(torch.int64)

    # Look every index up in the codebook and apply the per-block scale.
    codes = lut[idx].to(dtype)
    return scale * codes
454+
455+
def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
    """Dequantize IQ4_XS super-blocks.

    Layout per super-block: a float16 super-scale (2 bytes), 2 bytes of
    high scale bits, QK_K // 64 bytes of low scale bits, then packed 4-bit
    indices into the same fixed 16-entry non-linear codebook as IQ4_NL.
    Each sub-block scale is 6 bits (4 low + 2 high) stored with a +32 bias.
    """
    # Fixed non-linear 16-level codebook shared with IQ4_NL.
    lut = torch.tensor(
        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
        dtype=torch.float32,
        device=blocks.device,
    )
    n_blocks = blocks.shape[0]

    d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64)
    d = d.view(torch.float16).to(dtype)
    scales_h = scales_h.view(torch.int16)

    # Low 4 bits of each sub-block scale: two per byte, low nibble first.
    lo_shift = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape((1, 1, 2))
    scales_l = (scales_l.reshape((n_blocks, -1, 1)) >> lo_shift).reshape((n_blocks, -1)) & 0x0F

    # High 2 bits of each sub-block scale, packed into one int16.
    hi_shift = torch.tensor(
        [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8
    ).reshape((1, -1, 1))
    scales_h = (scales_h.reshape((n_blocks, 1, -1)) >> hi_shift).reshape((n_blocks, -1)) & 0x03

    # Recombine into the biased 6-bit scale and fold in the super-scale.
    sub_scales = (scales_l | (scales_h << 4)) - 32
    dl = (d * sub_scales.to(dtype)).reshape((n_blocks, -1, 1))

    # Unpack two 4-bit codebook indices per byte, then look them up.
    nib_shift = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    idx = qs.reshape((n_blocks, -1, 1, 16)) >> nib_shift
    idx = (idx & 15).reshape((n_blocks, -1, 32)).to(torch.int64)
    vals = lut[idx].to(dtype)

    return (dl * vals).reshape(n_blocks, -1)
484+
485+
432486GGML_QUANT_SIZES = gguf .GGML_QUANT_SIZES
433487dequantize_functions = {
488+ gguf .GGMLQuantizationType .IQ4_NL : dequantize_blocks_IQ4_NL ,
489+ gguf .GGMLQuantizationType .IQ4_XS : dequantize_blocks_IQ4_XS ,
434490 gguf .GGMLQuantizationType .BF16 : dequantize_blocks_BF16 ,
435491 gguf .GGMLQuantizationType .Q8_0 : dequantize_blocks_Q8_0 ,
436492 gguf .GGMLQuantizationType .Q5_1 : dequantize_blocks_Q5_1 ,
0 commit comments