@@ -29,28 +29,30 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
 class vLLMw8a8QuantizationMethod(vLLMBaseQuantizationMethod):
     def __init__(self):
         super().__init__()
-        self.input_scale = None

     def quantize(self, weight: torch.Tensor):
         if isinstance(weight, tuple):
-            if len(weight) == 3:
-                self.input_scale = weight[-1]
-            return weight[0].transpose(0, 1).cuda(), weight[1]
+            return (weight[0].transpose(0, 1).cuda(),) + weight[1:]
         weight = weight.float()
         scale = weight.abs().max(dim=-1)[0] / 127
         weight = weight.transpose(0, 1) / scale.reshape(1, -1)
         weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8)
         return weight.cuda(), scale.cuda()

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
-        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=self.input_scale, azp=None, symmetric=True)
+        input_scale = None
+        if len(weights) == 3:
+            qweight, weight_scale, input_scale = weights
+        elif len(weights) == 2:
+            qweight, weight_scale = weights
+        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
         m = input_tensor.shape[0]
-        n = weights[0].shape[1]
+        n = qweight.shape[1]
         if out is None:
             out = g_cache_manager.alloc_tensor(
                 (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
             )
-        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
         return out

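For context (not part of the commit), here is a minimal CPU-only sketch of the symmetric per-channel int8 quantization that the non-tuple branch of quantize() performs. The function name, shapes, and the dequantized-matmul check are hypothetical; the .cuda() transfer, ops.scaled_int8_quant, and the cutlass_scaled_mm kernel from the diff are not reproduced.

import torch

# Sketch of quantize()'s non-tuple path: one scale per output channel,
# weights rounded to int8 and stored transposed as (in_features, out_features).
def quantize_weight(weight: torch.Tensor):
    weight = weight.float()                          # (out_features, in_features)
    scale = weight.abs().max(dim=-1)[0] / 127        # per-output-channel scale
    qweight = weight.transpose(0, 1) / scale.reshape(1, -1)
    qweight = torch.round(qweight.clamp(min=-128, max=127)).to(torch.int8)
    return qweight, scale

w = torch.randn(8, 16)                               # hypothetical linear weight
qw, s = quantize_weight(w)

# Dequantized reference: qw * s approximates w.T, so x @ (qw * s) ~= x @ w.T,
# which is what the fused int8 GEMM in apply() computes on GPU.
x = torch.randn(4, 16)
approx = x @ (qw.float() * s.reshape(1, -1))
exact = x @ w.t()
print((approx - exact).abs().max())                  # small quantization error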