@@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
458458 """
459459
460460
461+ prev_device = pre_call (A .device )
461462 if code is None :
462463 if "dynamic" not in name2qmap :
463464 name2qmap ["dynamic" ] = create_dynamic_map ().to (A .device )
@@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
479480 is_on_gpu ([code , A , absmax , out , rand ])
480481 cblocksize = ct .c_int32 (blocksize )
481482 if rand is not None :
483+ is_on_gpu ([code , A , out , absmax , rand ])
482484 assert blocksize == 4096
483485 assert rand .numel () >= 1024
484486 rand_offset = random .randint (0 , 1023 )
@@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
489491 else :
490492 raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
491493 else :
494+ is_on_gpu ([code , A , out , absmax ])
492495 if A .dtype == torch .float32 :
493496 lib .cquantize_blockwise_fp32 (get_ptr (code ), get_ptr (A ), get_ptr (absmax ), get_ptr (out ), cblocksize , ct .c_int (A .numel ()))
494497 elif A .dtype == torch .float16 :
@@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
499502 # cpu
500503 assert rand is None
501504 lib .cquantize_blockwise_cpu_fp32 (get_ptr (code ), get_ptr (A ), get_ptr (absmax ), get_ptr (out ), ct .c_longlong (blocksize ), ct .c_longlong (A .numel ()))
505+ post_call (A .device )
502506
503507 return out , (absmax , code )
504508
@@ -537,6 +541,7 @@ def dequantize_blockwise(
537541 Dequantized tensor (default: float32)
538542 """
539543 assert quant_state is not None or absmax is not None
544+ device = pre_call (A .device )
540545 if code is None and quant_state is None :
541546 if "dynamic" not in name2qmap :
542547 name2qmap ["dynamic" ] = create_dynamic_map ().to (A .device )
@@ -561,6 +566,7 @@ def dequantize_blockwise(
561566 raise ValueError (f"Blockwise quantization only supports 16/32-bit floats, but got { A .dtype } " )
562567 else :
563568 lib .cdequantize_blockwise_cpu_fp32 (get_ptr (quant_state [1 ]), get_ptr (A ), get_ptr (quant_state [0 ]), get_ptr (out ), ct .c_longlong (blocksize ), ct .c_longlong (A .numel ()))
569+ post_call (A .device )
564570
565571 return out
566572
0 commit comments