|
8 | 8 | from loguru import logger |
9 | 9 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
10 | 10 |
|
11 | | -from .quant import FloatQuantizer |
12 | 11 | from .utils import is_fp8_supported_gpu |
13 | 12 |
|
14 | 13 | if is_fp8_supported_gpu(): |
|
23 | 22 | 'Using LLMC Quantizer implementation instead.' |
24 | 23 | ) |
25 | 24 |
|
| 25 | +try: |
| 26 | + from vllm import _custom_ops as ops |
| 27 | +except ModuleNotFoundError: |
| 28 | + ops = None |
26 | 29 |
|
27 | 30 | try: |
28 | 31 | import fast_hadamard_transform |
|
34 | 37 | 'If you need it, please install it first.' |
35 | 38 | ) |
36 | 39 |
|
37 | | -from .utils import calculate_zeros_width |
38 | | - |
39 | 40 |
|
40 | 41 | def block_wise_fp8_forward_func(x, w, w_scale, block_size, bias): |
41 | 42 | x, scale = act_quant(x, block_size) |
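
The truncated hunk above calls `act_quant(x, block_size)`, which is not part of this diff. As context only, here is a minimal sketch of the per-block activation quantization it is assumed to perform (block-wise absmax scaling into `float8_e4m3fn`, DeepSeek-style). The function name, the reshape layout, and the 448 e4m3 max are assumptions, not the project's actual kernel.

```python
import torch

def act_quant_reference(x, block_size=128):
    # Split the last dimension into groups of `block_size`; one scale per group.
    *lead, hidden = x.shape
    xb = x.reshape(*lead, hidden // block_size, block_size)
    # Map each block's absmax to the float8 e4m3 maximum (448).
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 448.0
    x_fp8 = (xb / scale).to(torch.float8_e4m3fn).reshape(x.shape)
    return x_fp8, scale.squeeze(-1)
```
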
@@ -127,13 +128,105 @@ def new(cls, module): |
127 | 128 | return new_module |
128 | 129 |
|
129 | 130 |
|
| 131 | +class VllmQuantLinearInt8(nn.Module): |
| 132 | + def __init__(self, in_features, out_features, bias=True): |
| 133 | + super().__init__() |
| 134 | + self.in_features = in_features |
| 135 | + self.out_features = out_features |
| 136 | + |
| 137 | + self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.int8)) |
| 138 | + self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32)) |
| 139 | + |
| 140 | + if bias: |
| 141 | + self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16)) |
| 142 | + else: |
| 143 | + self.register_buffer('bias', None) |
| 144 | + |
| 145 | + def act_quant_func(self, x): |
| 146 | + input_tensor_quant, input_tensor_scale, _ \ |
| 147 | + = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True) |
| 148 | + return input_tensor_quant, input_tensor_scale |
| 149 | + |
| 150 | + def forward(self, input_tensor): |
| 151 | + input_tensor = input_tensor.squeeze(0) |
| 152 | + shape = (input_tensor.shape[0], self.weight.shape[0]) |
| 153 | + dtype = input_tensor.dtype |
| 154 | + device = input_tensor.device |
| 155 | + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) |
| 156 | + |
| 157 | + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) |
| 158 | + torch.ops._C.cutlass_scaled_mm( |
| 159 | + output_tensor, |
| 160 | + input_tensor_quant, |
| 161 | + self.weight.t(), |
| 162 | + input_tensor_scale, |
| 163 | + self.weight_scale.float(), |
| 164 | + self.bias, |
| 165 | + ) |
| 166 | + return output_tensor.unsqueeze(0) |
| 167 | + |
| 168 | + @classmethod |
| 169 | + @torch.no_grad() |
| 170 | + def new(cls, module): |
| 171 | + in_features = module.in_features |
| 172 | + out_features = module.out_features |
| 173 | + bias = module.bias is not None |
| 174 | + new_module = cls(in_features, out_features, bias) |
| 175 | + return new_module |
| 176 | + |
| 177 | + |
| 178 | +class VllmQuantLinearFp8(nn.Module): |
| 179 | + def __init__(self, in_features, out_features, bias=True): |
| 180 | + super().__init__() |
| 181 | + self.in_features = in_features |
| 182 | + self.out_features = out_features |
| 183 | + self.register_buffer('weight', torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn)) # noqa |
| 184 | + self.register_buffer('weight_scale', torch.empty((out_features, 1), dtype=torch.float32)) |
| 185 | + if bias: |
| 186 | + self.register_buffer('bias', torch.empty(out_features, dtype=torch.bfloat16)) |
| 187 | + else: |
| 188 | + self.register_buffer('bias', None) |
| 189 | + |
| 190 | + def act_quant_func(self, x): |
| 191 | + input_tensor_quant, input_tensor_scale \ |
| 192 | + = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True) |
| 193 | + return input_tensor_quant, input_tensor_scale |
| 194 | + |
| 195 | + def forward(self, input_tensor): |
| 196 | + input_tensor = input_tensor.squeeze(0) |
| 197 | + shape = (input_tensor.shape[0], self.weight.shape[0]) |
| 198 | + dtype = input_tensor.dtype |
| 199 | + device = input_tensor.device |
| 200 | + output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False) |
| 201 | + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) |
| 202 | + torch.ops._C.cutlass_scaled_mm( |
| 203 | + output_tensor, |
| 204 | + input_tensor_quant, |
| 205 | + self.weight.t(), |
| 206 | + input_tensor_scale, |
| 207 | + self.weight_scale.float(), |
| 208 | + self.bias, |
| 209 | + ) |
| 210 | + |
| 211 | + return output_tensor.unsqueeze(0) |
| 212 | + |
| 213 | + @classmethod |
| 214 | + @torch.no_grad() |
| 215 | + def new(cls, module): |
| 216 | + in_features = module.in_features |
| 217 | + out_features = module.out_features |
| 218 | + bias = module.bias is not None |
| 219 | + new_module = cls(in_features, out_features, bias) |
| 220 | + return new_module |
| 221 | + |
| 222 | + |
130 | 223 | class LlmcFp8Linear(nn.Module): |
131 | 224 | def __init__(self, in_features, out_features, bias, block_size): |
132 | 225 | super().__init__() |
133 | 226 | self.block_size = block_size |
134 | 227 | self.in_features = in_features |
135 | 228 | self.out_features = out_features |
136 | | - if bias is not None: |
| 229 | + if bias: |
137 | 230 | self.bias = nn.Parameter(torch.empty(out_features)) |
138 | 231 | else: |
139 | 232 | self.register_parameter('bias', None) |
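
The two classes added above share one pattern: dynamically quantize the activations per token, then hand both scale tensors to vLLM's `cutlass_scaled_mm`. For checking the kernel output, a plain-PyTorch reference of the int8 path can be useful; this is a sketch under the assumption of symmetric per-token activation scales and per-output-channel weight scales (matching the registered buffers), not the vLLM implementation itself.

```python
import torch

def int8_linear_reference(x, weight_int8, weight_scale, bias=None):
    # Dynamic symmetric per-token int8 quantization of the activations
    # (the role ops.scaled_int8_quant plays in act_quant_func above).
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / x_scale), -127, 127)

    # Integer GEMM followed by rescaling, i.e.
    # y = (x_q @ w_q.T) * x_scale * w_scale.T (+ bias),
    # which is what the CUTLASS kernel fuses on the GPU.
    y = (x_q.float() @ weight_int8.float().t()) * x_scale * weight_scale.float().t()
    if bias is not None:
        y = y + bias.float()
    return y.to(x.dtype)
```
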
@@ -172,7 +265,7 @@ def forward(self, x): |
172 | 265 | def new(cls, module, block_size): |
173 | 266 | in_features = module.in_features |
174 | 267 | out_features = module.out_features |
175 | | - bias = module.bias |
| 268 | + bias = module.bias is not None |
176 | 269 | new_module = cls(in_features, out_features, bias, block_size) |
177 | 270 | return new_module |
178 | 271 |
|
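
For completeness, a hedged usage sketch of the new modules: the `.new()` classmethods build empty-shelled replacements for an existing `nn.Linear`, and the quantized weights and scales are expected to be loaded into the registered buffers afterwards (how that happens is outside this diff). The forward path needs vLLM's compiled `_custom_ops` on a CUDA device.

```python
import torch
import torch.nn as nn

# Build replacement modules from an existing layer; buffers start uninitialized.
orig = nn.Linear(4096, 4096, bias=False).to(torch.bfloat16)
int8_linear = VllmQuantLinearInt8.new(orig)   # int8 weight + fp32 per-channel scale buffers
fp8_linear = VllmQuantLinearFp8.new(orig)     # float8_e4m3fn weight + fp32 scale buffers

# After real quantized weights/scales are loaded into the buffers, forward expects a
# leading batch dimension of 1 (it squeezes/unsqueezes dim 0) and vLLM's CUTLASS kernels:
# x = torch.randn(1, 16, 4096, dtype=torch.bfloat16, device='cuda')
# y = int8_linear.to('cuda')(x)
```
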