@@ -1579,22 +1579,22 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    quant_state=None
+    state=None
 ):
     prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
-    if quant_state is None:
+    if state is None:
         raise ValueError(f'state cannot be None. gemv_4bit() requires the state from quantize_4bit()')

     if A.numel() != A.shape[-1]:
         raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')

-    Bshape = quant_state.shape
+    Bshape = state.shape
     bout = Bshape[0]
-    absmax = quant_state.absmax
-    if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-        absmax += quant_state.offset
+    absmax = state.absmax
+    if state.nested:
+        absmax = dequantize_blockwise(state.absmax, state.state2)
+        absmax += state.offset

     if out is None:
         if len(A.shape) == 3:
@@ -1608,7 +1608,7 @@ def gemv_4bit(
     lda = Bshape[0]
     ldc = Bshape[0]
     ldb = (A.shape[-1]+1)//2
-    is_on_gpu([B, A, out, absmax, quant_state.code])
+    is_on_gpu([B, A, out, absmax, state.code])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)
@@ -1618,11 +1618,11 @@ def gemv_4bit(

     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         elif A.dtype == torch.float32:
-            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(quant_state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(quant_state.blocksize))
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

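For reference, a minimal usage sketch of the renamed keyword, assuming the quantize_4bit()/gemv_4bit() APIs exactly as they appear in this revision (shapes and quant_type are illustrative; a CUDA device is assumed):

    import torch
    import bitsandbytes.functional as F

    # Quantize a weight matrix to 4 bits; the returned state carries
    # shape, absmax, code and blocksize (plus state2/offset when nested).
    W = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qB, state = F.quantize_4bit(W, quant_type='nf4')

    # gemv_4bit requires A to be a vector, e.g. [1, 1, 4096]; the
    # quantization state is now passed as state= rather than quant_state=.
    x = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
    out = F.gemv_4bit(x, qB.t(), state=state)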
@@ -1904,7 +1904,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):

 def mm_dequant(
     A,
-    state,
+    quant_state,
     row_stats,
     col_stats,
     out=None,
@@ -1914,7 +1914,7 @@ def mm_dequant(
 ):
     assert A.dtype == torch.int32
     if bias is not None: assert bias.dtype == torch.float16
-    out_shape = state[0]
+    out_shape = quant_state[0]
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])

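Conversely, mm_dequant() now names its second argument quant_state; note this is the (shape, order) pair returned by igemmlt(), not the 4-bit state above. A rough sketch of the int8 path, assuming the double_quant()/transform()/igemmlt() helpers from this same module (shapes illustrative; 'col_turing' vs 'col_ampere' depends on the GPU generation):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(8, 4096, dtype=torch.float16, device='cuda')
    B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')  # [out, in]

    # Row-wise int8 quantization; SCA/SCB are the absmax statistics that
    # mm_dequant later consumes as row_stats/col_stats.
    CA, CAt, SCA, SCAt, _ = F.double_quant(A)
    CB, CBt, SCB, SCBt, _ = F.double_quant(B)

    # Lay the operands out for the int8 tensor cores, run the int32 GEMM,
    # then dequantize; Sout32 is what this function receives as quant_state.
    C32A, SA = F.transform(CA, 'col32')
    CxB, SB = F.transform(CB, 'col_turing')
    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    out = F.mm_dequant(out32, Sout32, SCA, SCB)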