Skip to content

Commit 7b36448

Browse files
author
niushengxiao
committed
fix: continue fix
1 parent bfb50d6 commit 7b36448

File tree

2 files changed

+5
-12
lines changed

2 files changed

+5
-12
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ def _slice_weight(self, tensor):
3939
self.weight_tp_size = tp_size
4040
return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_)
4141

42-
def _slice_bias(self, bias):
43-
tp_size = bias.shape[0] // self.tp_world_size_
44-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_)
45-
4642

4743
class W8A8B128COLMMWeight(UnquantizedCOLMMWeight):
4844
def __init__(
@@ -85,8 +81,7 @@ def _post_process_weight_scale(self, weight_scale) -> None:
8581
def _post_process_weight(self, weight) -> None:
8682
self.weight = weight.cuda(get_current_device_id()).transpose(0, 1)
8783

88-
def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
89-
super()._load_weights(weights)
84+
def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
9085
if self.weight_scale_name is not None and self.weight_scale_name in weights:
9186
weight_scale = weights[self.weight_scale_name]
9287
# per channel or block-wise

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _slice_weight(self, weight: torch.Tensor):
6666

6767
def _slice_bias(self, bias):
6868
tp_size = bias.shape[0] // self.tp_world_size_
69-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_)
69+
return (bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_).to(self.data_type_)
7070

7171

7272
class W8A8B128ROWMMWeight(UnquantizedROWMMWeight):
@@ -98,7 +98,7 @@ def _slice_weight(self, weight: torch.Tensor):
9898

9999
def _slice_bias(self, bias):
100100
tp_size = bias.shape[0] // self.tp_world_size_
101-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
101+
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
102102

103103
def _slice_weight_scale(self, weight_scale: torch.Tensor):
104104
scale_start = (self.weight_tp_size * self.tp_rank_ + self.block_size - 1) // self.block_size
@@ -114,8 +114,7 @@ def _post_process_weight_scale(self, weight_scale) -> None:
114114
def _post_process_weight(self, weight) -> None:
115115
self.weight = weight.cuda(get_current_device_id()).transpose(0, 1)
116116

117-
def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
118-
super()._load_weights(weights)
117+
def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
119118
if self.weight_scale_name is not None and self.weight_scale_name in weights:
120119
weight_scale = weights[self.weight_scale_name]
121120
# per channel or block-wise
@@ -296,8 +295,7 @@ def _slice_weight_scale(self, weight_scale: torch.Tensor):
296295
scale_end = self.weight_tp_size * (self.tp_rank_ + 1)
297296
return weight_scale[scale_start : scale_end].to(torch.float)
298297

299-
def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
300-
super()._load_weights(weights)
298+
def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
301299
if self.weight_scale_name is not None and self.weight_scale_name in weights:
302300
weight_scale = weights[self.weight_scale_name]
303301
# per channel or block-wise

0 commit comments

Comments (0)