@@ -429,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
     return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
 
 
+# this part is from calcuis (gguf.org)
+# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py
+
+
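+# IQ4_NL: each block stores one fp16 scale d followed by block_size // 2 bytes,
+# each byte packing two 4-bit indices into the non-linear kvalues codebook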
+def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
+    kvalues = torch.tensor(
+        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+        dtype=torch.float32,
+        device=blocks.device,
+    )
+    n_blocks = blocks.shape[0]
+    d, qs = split_block_dims(blocks, 2)
+    d = d.view(torch.float16).to(dtype)
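+    # each byte of qs carries two 4-bit indices: shift by 0 for the low
+    # nibble and by 4 for the high nibble, then mask with 15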
+    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+        [0, 4], device=blocks.device, dtype=torch.uint8
+    ).reshape((1, 1, 2, 1))
+    qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64)
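+    # look up each index in the codebook with gather, then apply the scale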
+    kvalues = kvalues.view(1, 1, 16)
+    qs = qs.unsqueeze(-1)
+    qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs)
+    qs = qs.squeeze(-1).to(dtype)
+    return d * qs
+
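+# a minimal driving sketch (hypothetical tensor name raw_uint8; block_size and
+# type_size come from gguf.GGML_QUANT_SIZES, which for IQ4_NL gives (32, 18),
+# i.e. 2 scale bytes + 16 index bytes per block):
+#   blocks = raw_uint8.reshape(-1, 18)
+#   deq = dequantize_blocks_IQ4_NL(blocks, 32, 18, dtype=torch.float32)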
+
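+# IQ4_XS: a 256-value super-block layout; besides the fp16 super-scale d it
+# stores 6-bit sub-block scales split across scales_h (high 2 bits, packed in
+# an int16) and scales_l (low 4 bits, packed as nibbles), plus QK_K // 2
+# bytes of 4-bit indices into the same kvalues codebook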
+def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
+    kvalues = torch.tensor(
+        [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+        dtype=torch.float32,
+        device=blocks.device,
+    )
+    n_blocks = blocks.shape[0]
+    d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64)
+    d = d.view(torch.float16).to(dtype)
+    scales_h = scales_h.view(torch.int16)
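+    # rebuild each 6-bit sub-block scale: low 4 bits from a scales_l nibble,
+    # high 2 bits from scales_h, then re-center by subtracting 32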
+    scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor(
+        [0, 4], device=blocks.device, dtype=torch.uint8
+    ).reshape((1, 1, 2))
+    scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor(
+        [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8
+    ).reshape((1, -1, 1))
+    scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F
+    scales_h = scales_h.reshape((n_blocks, -1)) & 0x03
+    scales = (scales_l | (scales_h << 4)) - 32
+    dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1))
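+    # unpack the 4-bit indices (two per byte), grouped 32 per sub-block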
+    shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+    qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q
+    qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64)
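+    # codebook lookup as in IQ4_NL, but keeping the per-sub-block axis so
+    # each group of 32 values gets its own scale dl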
+    kvalues = kvalues.view(1, 1, 1, 16)
+    qs = qs.unsqueeze(-1)
+    qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs)
+    qs = qs.squeeze(-1).to(dtype)
+    return (dl * qs).reshape(n_blocks, -1)
+
+
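+# the same driving sketch for IQ4_XS (hypothetical tensor name raw_uint8;
+# with QK_K = 256, type_size is 2 + 2 + 4 + 128 = 136 bytes per super-block):
+#   blocks = raw_uint8.reshape(-1, 136)
+#   deq = dequantize_blocks_IQ4_XS(blocks, 256, 136, dtype=torch.float32)
+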
 GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
 dequantize_functions = {
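+    # route the new 4-bit IQ types to the dequantizers defined above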
+    gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
+    gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
     gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
     gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
     gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,