Skip to content

Commit d31b122

Browse files
committed
merge
2 parents 7ea3bb4 + 7b36448 commit d31b122

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
tp_world_size: int = None,
5555
) -> None:
5656
super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size)
57-
57+
5858
self.weight_scale_name, self.act_scale_name = generate_scale_name(
5959
weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix
6060
)
@@ -68,9 +68,9 @@ def _slice_weight(self, tensor):
6868

6969
def _slice_weight_scale(self, weight_scale: torch.Tensor):
7070
tp_size = weight_scale.shape[1] // self.tp_world_size_
71-
scale_start = tp_size * self.tp_rank_
71+
scale_start = tp_size * self.tp_rank_
7272
scale_end = tp_size * (self.tp_rank_ + 1)
73-
return weight_scale[:, scale_start: scale_end].to(torch.float)
73+
return weight_scale[:, scale_start:scale_end].to(torch.float)
7474

7575
def _process_weight_scale(self, weight_scale) -> None:
7676
self.weight_scale = weight_scale.transpose(0, 1).cuda(get_current_device_id())

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _slice_weight(self, weight: torch.Tensor):
6666

6767
def _slice_bias(self, bias):
6868
tp_size = bias.shape[0] // self.tp_world_size_
69-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_)
69+
return (bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_).to(self.data_type_)
7070

7171

7272
class W8A8B128ROWMMWeight(UnquantizedROWMMWeight):
@@ -93,13 +93,13 @@ def _slice_weight(self, weight: torch.Tensor):
9393

9494
def _slice_bias(self, bias):
9595
tp_size = bias.shape[0] // self.tp_world_size_
96-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
96+
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
9797

9898
def _slice_weight_scale(self, weight_scale: torch.Tensor):
9999
tp_size = weight_scale.shape[0] // self.tp_world_size_
100100
scale_start = tp_size * self.tp_rank_
101101
scale_end = tp_size * (self.tp_rank_ + 1)
102-
return weight_scale.to(torch.float)[scale_start : scale_end]
102+
return weight_scale.to(torch.float)[scale_start:scale_end]
103103

104104
def _process_weight_scale(self, weight_scale) -> None:
105105
self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1)
@@ -112,15 +112,16 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
112112
weight_scale = weights[self.weight_scale_name]
113113
weight_scale = self._slice_weight_scale(weight_scale)
114114
self._process_weight_scale(weight_scale)
115-
115+
116116
if self.weight_scale is not None and isinstance(self.weight, torch.Tensor):
117117
self.weight = [
118118
self.weight,
119119
self.weight_scale,
120-
None, # placeholder for input scale
120+
None, # placeholder for input scale
121121
]
122122
return
123123

124+
124125
class UnquantizedMultiROWMMWeight(MultiMMWeightTpl):
125126
_slice_weight = UnquantizedROWMMWeight._slice_weight
126127
_slice_bias = UnquantizedROWMMWeight._slice_bias
@@ -156,7 +157,9 @@ def __init__(
156157
self.weight_scale: Optional[torch.Tensor] = None
157158
self.weight_scales = [None] * len(self.weight_names)
158159
for weight_name in weight_names:
159-
weight_scale_name, act_scale_name = generate_scale_name(weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix)
160+
weight_scale_name, act_scale_name = generate_scale_name(
161+
weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix
162+
)
160163
self.weight_scale_names.append(weight_scale_name)
161164
self.quantized_weight = True
162165

@@ -244,14 +247,14 @@ def _slice_weight_scale(self, weight_scale: torch.Tensor):
244247
tp_size = weight_scale.shape[0] // self.tp_world_size_
245248
scale_start = tp_size * self.tp_rank_
246249
scale_end = tp_size * (self.tp_rank_ + 1)
247-
return weight_scale[scale_start : scale_end].to(torch.float)
250+
return weight_scale[scale_start:scale_end].to(torch.float)
248251

249252
def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
250253
if self.weight_scale_name is not None and self.weight_scale_name in weights:
251254
weight_scale = weights[self.weight_scale_name]
252255
weight_scale = self._slice_weight_scale(weight_scale)
253256

254-
if self.weight_name in weights and self.weight_scale is not None:
257+
if self.weight_scale is not None and isinstance(self.weight, torch.Tensor):
255258
self.weight = [
256259
self.weight,
257260
self.weight_scale,

0 commit comments

Comments
 (0)