@@ -66,7 +66,7 @@ def _slice_weight(self, weight: torch.Tensor):
6666
6767 def _slice_bias (self , bias ):
6868 tp_size = bias .shape [0 ] // self .tp_world_size_
69- return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )].to (self .data_type_ )
69+ return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )] / self . tp_world_size_ .to (self .data_type_ )
7070
7171
7272class W8A8B128ROWMMWeight (UnquantizedROWMMWeight ):
@@ -98,7 +98,7 @@ def _slice_weight(self, weight: torch.Tensor):
9898
9999 def _slice_bias (self , bias ):
100100 tp_size = bias .shape [0 ] // self .tp_world_size_
101- return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )]
101+ return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )] / self . tp_world_size_
102102
103103 def _slice_weight_scale (self , weight_scale : torch .Tensor ):
104104 scale_start = (self .weight_tp_size * self .tp_rank_ + self .block_size - 1 ) // self .block_size
@@ -114,8 +114,7 @@ def _post_process_weight_scale(self, weight_scale) -> None:
def _post_process_weight(self, weight) -> None:
    """Move ``weight`` to the current CUDA device and store it transposed.

    The tensor is placed on the device identified by
    ``get_current_device_id()`` and its first two dimensions are swapped;
    the result is assigned to ``self.weight``.
    """
    device_id = get_current_device_id()
    on_device = weight.cuda(device_id)
    self.weight = on_device.transpose(0, 1)
116116
117- def _load_weights (self , weights : Dict [str , torch .Tensor ]) -> None :
118- super ()._load_weights (weights )
117+ def _load_scales (self , weights : Dict [str , torch .Tensor ]) -> None :
119118 if self .weight_scale_name is not None and self .weight_scale_name in weights :
120119 weight_scale = weights [self .weight_scale_name ]
121120 # per channel or block-wise
@@ -296,8 +295,7 @@ def _slice_weight_scale(self, weight_scale: torch.Tensor):
296295 scale_end = self .weight_tp_size * (self .tp_rank_ + 1 )
297296 return weight_scale [scale_start : scale_end ].to (torch .float )
298297
299- def _load_weights (self , weights : Dict [str , torch .Tensor ]) -> None :
300- super ()._load_weights (weights )
298+ def _load_scales (self , weights : Dict [str , torch .Tensor ]) -> None :
301299 if self .weight_scale_name is not None and self .weight_scale_name in weights :
302300 weight_scale = weights [self .weight_scale_name ]
303301 # per channel or block-wise
0 commit comments