Skip to content

Commit 1b2ac13

Browse files
committed
fix set quantization
1 parent 267b540 commit 1b2ac13

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base_weight import BaseWeight
22
from .mm_weight import (
3+
MMWeightTpl,
34
MMWeight,
45
MultiMMWeight,
56
ROWMMWeight,

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def mm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
3030
return torch.addmm(self.bias, input_tensor, self.weight, out=out)
3131

3232
def _post_load_weights(self):
33+
print(self.quant_method)
3334
if self.quant_method is not None:
3435
self.weight = self.quant_method.quantize(self.weight.cuda(self.tp_rank_))
3536
return

lightllm/common/basemodel/layer_weights/transformer_layer_weight.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# from lightllm.common.layers.mm import MM
44
from .base_layer_weight import BaseLayerWeight
5-
from .meta_weights import BaseWeight, MultiMMWeight, MMWeight, FusedMoeWeight
5+
from .meta_weights import BaseWeight, MultiMMWeight, MMWeightTpl, FusedMoeWeight
66
from lightllm.utils.log_utils import init_logger
77

88
logger = init_logger(__name__)
@@ -51,7 +51,7 @@ def set_quantization(self):
5151
mix_quant_list = self.quant_cfg.get_mixed_list(self.layer_num_)
5252
for attr_name in dir(self):
5353
attr = getattr(self, attr_name)
54-
if isinstance(attr, MMWeight) or isinstance(attr, FusedMoeWeight):
54+
if isinstance(attr, MMWeightTpl) or isinstance(attr, FusedMoeWeight):
5555
if attr_name in mix_quant_list:
5656
attr.set_quant_method(self.quant_cfg.get_quant_method(self.layer_num_, attr_name))
5757
attr_quant_type = self.quant_cfg.get_quant_type(self.layer_num_, attr_name)

0 commit comments

Comments
 (0)