@@ -466,7 +466,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
 
     if absmax is None:
         n = A.numel()
-        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+        blocksize = (blocksize if A.device.type == 'cuda' else 4096)
         blocks = n // blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
@@ -550,17 +550,15 @@ def dequantize_blockwise(
 
 
     if A.device.type != 'cpu':
-        if blocksize not in [2048, 4096]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
+        if blocksize not in [2048, 4096, 1024, 512]:
+            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
             lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
             lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
-            raise ValueError(
-                f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-            )
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     else:
         lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
0 commit comments