@@ -66,7 +66,7 @@ def _slice_weight(self, weight: torch.Tensor):
6666
6767 def _slice_bias (self , bias ):
6868 tp_size = bias .shape [0 ] // self .tp_world_size_
69- return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )].to (self .data_type_ )
69+ return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )] / self . tp_world_size_ .to (self .data_type_ )
7070
7171
7272class W8A8B128ROWMMWeight (UnquantizedROWMMWeight ):
@@ -93,13 +93,13 @@ def _slice_weight(self, weight: torch.Tensor):
9393
9494 def _slice_bias (self , bias ):
9595 tp_size = bias .shape [0 ] // self .tp_world_size_
96- return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )]
96+ return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )] / self . tp_world_size_
9797
9898 def _slice_weight_scale (self , weight_scale : torch .Tensor ):
9999 tp_size = weight_scale .shape [0 ] // self .tp_world_size_
100100 scale_start = tp_size * self .tp_rank_
101101 scale_end = tp_size * (self .tp_rank_ + 1 )
102- return weight_scale .to (torch .float )[scale_start : scale_end ]
102+ return weight_scale .to (torch .float )[scale_start : scale_end ]
103103
104104 def _process_weight_scale (self , weight_scale ) -> None :
105105 self .weight_scale = weight_scale .cuda (get_current_device_id ()).transpose (0 , 1 )
@@ -112,15 +112,16 @@ def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
112112 weight_scale = weights [self .weight_scale_name ]
113113 weight_scale = self ._slice_weight_scale (weight_scale )
114114 self ._process_weight_scale (weight_scale )
115-
115+
116116 if self .weight_scale is not None and isinstance (self .weight , torch .Tensor ):
117117 self .weight = [
118118 self .weight ,
119119 self .weight_scale ,
120- None , # placeholder for input scale
120+ None , # placeholder for input scale
121121 ]
122122 return
123123
124+
124125class UnquantizedMultiROWMMWeight (MultiMMWeightTpl ):
125126 _slice_weight = UnquantizedROWMMWeight ._slice_weight
126127 _slice_bias = UnquantizedROWMMWeight ._slice_bias
@@ -156,7 +157,9 @@ def __init__(
156157 self .weight_scale : Optional [torch .Tensor ] = None
157158 self .weight_scales = [None ] * len (self .weight_names )
158159 for weight_name in weight_names :
159- weight_scale_name , act_scale_name = generate_scale_name (weight_name , quant_method .weight_scale_suffix , quant_method .act_scale_suffix )
160+ weight_scale_name , act_scale_name = generate_scale_name (
161+ weight_name , quant_method .weight_scale_suffix , quant_method .act_scale_suffix
162+ )
160163 self .weight_scale_names .append (weight_scale_name )
161164 self .quantized_weight = True
162165
@@ -244,14 +247,14 @@ def _slice_weight_scale(self, weight_scale: torch.Tensor):
244247 tp_size = weight_scale .shape [0 ] // self .tp_world_size_
245248 scale_start = tp_size * self .tp_rank_
246249 scale_end = tp_size * (self .tp_rank_ + 1 )
247- return weight_scale [scale_start : scale_end ].to (torch .float )
250+ return weight_scale [scale_start : scale_end ].to (torch .float )
248251
249252 def _load_scales (self , weights : Dict [str , torch .Tensor ]) -> None :
250253 if self .weight_scale_name is not None and self .weight_scale_name in weights :
251254 weight_scale = weights [self .weight_scale_name ]
252255 weight_scale = self ._slice_weight_scale (weight_scale )
253256
254- if self .weight_name in weights and self .weight_scale is not None :
257+ if self .weight_scale is not None and isinstance ( self .weight , torch . Tensor ) :
255258 self .weight = [
256259 self .weight ,
257260 self .weight_scale ,
0 commit comments