|
6 | 6 | from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm |
7 | 7 | from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops |
8 | 8 | from typing import Any |
9 | | -from typing import TYPE_CHECKING, Optional |
| 9 | +from typing import TYPE_CHECKING, Optional, Tuple |
| 10 | +from lightllm.utils.dist_utils import get_current_device_id |
10 | 11 |
|
11 | 12 | if TYPE_CHECKING: |
12 | 13 | from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack |
@@ -118,6 +119,27 @@ def method_name(self): |
118 | 119 | def quantize(self, weight: torch.Tensor): |
119 | 120 | raise NotImplementedError("AWQ online quantization is not supported yet.") |
120 | 121 |
|
| 122 | + def params_need_repack(self) -> bool: |
| 123 | + """ |
| 124 | + Indicates whether the quantized weights need to be repacked after quantization; currently only AWQ supports this.
| 125 | + """ |
| 126 | + return True |
| 127 | + |
| 128 | + def params_repack( |
| 129 | + self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype |
| 130 | + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 131 | + """ |
| 132 | + Some quantization methods (e.g. AWQ) repack their parameters after quantization so that the kernels run at optimal performance.
| 133 | + """ |
| 134 | + weight = self._process_weight_after_loading(weight.cuda(get_current_device_id())) |
| 135 | + weight_scale = self._process_weight_scale_after_loading( |
| 136 | + weight_scale.cuda(get_current_device_id()).to(dtype_type) |
| 137 | + ) |
| 138 | + weight_zero_point = self._process_weight_zero_point_after_loading( |
| 139 | + weight_zero_point.cuda(get_current_device_id()) |
| 140 | + ) |
| 141 | + return weight, weight_scale, weight_zero_point |
| 142 | + |
121 | 143 | def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: |
122 | 144 | assert self.hf_quantization_config is not None, "hf_quantization_config is not set" |
123 | 145 | self.k = weight.shape[0] |
|
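
For context, here is a minimal usage sketch of how a weight container might call the new hooks during loading. The class `AwqMMWeightPack` and its attributes (`quant_method`, `data_type`, `weight`, `weight_scale`, `weight_zero_point`) are illustrative assumptions, not the actual lightllm API; only `params_need_repack()` and `params_repack()` come from the diff above.

```python
# Minimal usage sketch, assuming a weight container that owns a quant_method.
# AwqMMWeightPack and its attribute names are hypothetical; only
# params_need_repack() / params_repack() are defined in the diff above.
import torch


class AwqMMWeightPack:
    def __init__(self, quant_method, data_type: torch.dtype):
        self.quant_method = quant_method
        self.data_type = data_type
        self.weight = None
        self.weight_scale = None
        self.weight_zero_point = None

    def load_quantized(self, weight: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
        # AWQ reports params_need_repack() == True, so the loaded tensors are
        # repacked once (moved to the current CUDA device inside params_repack)
        # before being stored on the container.
        if self.quant_method.params_need_repack():
            weight, scale, zero_point = self.quant_method.params_repack(
                weight, scale, zero_point, self.data_type
            )
        self.weight = weight
        self.weight_scale = scale
        self.weight_zero_point = zero_point
```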