77class SliceMixinBase (ABC ):
88 """切片操作的Mixin基类"""
99
10- def __init__ (self , tp_rank : int = None , tp_world_size : int = None ):
10+ def __init__ (self , tp_rank : int = None , tp_world_size : int = None , bias_div_world_size : bool = False ):
1111 self .tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp ()
1212 self .tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size ()
13+ self .bias_div_world_size_ = bias_div_world_size
1314
1415 @abstractmethod
1516 def _slice_weight (self , weight : torch .Tensor ):
@@ -21,8 +22,8 @@ def _slice_bias(self, bias):
2122
2223
2324class SliceMixinTpl (SliceMixinBase ):
24- def __init__ (self , tp_rank : int = None , tp_world_size : int = None ):
25- super ().__init__ (tp_rank , tp_world_size )
25+ def __init__ (self , tp_rank : int = None , tp_world_size : int = None , bias_div_world_size : bool = False ):
26+ super ().__init__ (tp_rank , tp_world_size , bias_div_world_size )
2627
2728 def _slice_weight (self , weight : torch .Tensor ) -> torch .Tensor :
2829 raise NotImplementedError ("slice_weight must implement this method" )
@@ -40,8 +41,8 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
4041# 默认weight 的shape是 outxin,这也是目前最通用的约定。
4142# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
4243class RowSliceMixin (SliceMixinTpl ):
43- def __init__ (self , tp_rank : int = None , tp_world_size : int = None ):
44- super ().__init__ (tp_rank , tp_world_size )
44+ def __init__ (self , tp_rank : int = None , tp_world_size : int = None , bias_div_world_size : bool = False ):
45+ super ().__init__ (tp_rank , tp_world_size , bias_div_world_size )
4546
4647 def _slice_weight (self , weight : torch .Tensor ) -> torch .Tensor :
4748 assert weight .shape [0 ] % self .tp_world_size_ == 0 , f"tp slice error { weight .shape [0 ]} % { self .tp_world_size_ } "
@@ -51,14 +52,16 @@ def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
5152 def _slice_bias (self , bias ) -> torch .Tensor :
5253 assert bias .shape [0 ] % self .tp_world_size_ == 0 , f"tp slice error { bias .shape [0 ]} % { self .tp_world_size_ } "
5354 tp_size = bias .shape [0 ] // self .tp_world_size_
55+ if self .bias_div_world_size_ :
56+ return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )] / self .tp_world_size_
5457 return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )]
5558
5659
5760# 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。
5861# 后续按需要,扩展per-tensor、per-channel的量化方式。
5962class QuantizedRowSliceMixin (RowSliceMixin ):
60- def __init__ (self , tp_rank : int = None , tp_world_size : int = None ):
61- super ().__init__ (tp_rank , tp_world_size )
63+ def __init__ (self , tp_rank : int = None , tp_world_size : int = None , bias_div_world_size : bool = False ):
64+ super ().__init__ (tp_rank , tp_world_size , bias_div_world_size )
6265
6366 def _slice_weight_scale (self , weight_scale : torch .Tensor ) -> torch .Tensor :
6467 assert (
@@ -80,8 +83,8 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
8083
8184
8285class ColSliceMixin (SliceMixinTpl ):
83- def __init__ (self , tp_rank : int = None , tp_world_size : int = None ):
84- super ().__init__ (tp_rank , tp_world_size )
86+ def __init__ (self , tp_rank : int = None , tp_world_size : int = None , bias_div_world_size : bool = True ):
87+ super ().__init__ (tp_rank , tp_world_size , bias_div_world_size )
8588
8689 def _slice_weight (self , weight : torch .Tensor ) -> torch .Tensor :
8790 assert weight .shape [1 ] % self .tp_world_size_ == 0 , f"tp slice error { weight .shape [1 ]} % { self .tp_world_size_ } "
@@ -91,12 +94,14 @@ def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
9194 def _slice_bias (self , bias ) -> torch .Tensor :
9295 assert bias .shape [0 ] % self .tp_world_size_ == 0 , f"tp slice error { bias .shape [0 ]} % { self .tp_world_size_ } "
9396 tp_size = bias .shape [0 ] // self .tp_world_size_
97+ if self .bias_div_world_size_ :
98+ return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )] / self .tp_world_size_
9499 return bias [tp_size * self .tp_rank_ : tp_size * (self .tp_rank_ + 1 )]
95100
96101
97102class QuantizedColSliceMixin (ColSliceMixin ):
98- def __init__ (self , tp_rank : int = None , tp_world_size : int = None ):
99- super ().__init__ (tp_rank , tp_world_size )
103+ def __init__ (self , tp_rank : int = None , tp_world_size : int = None , bias_div_world_size : bool = True ):
104+ super ().__init__ (tp_rank , tp_world_size , bias_div_world_size )
100105
101106 def _slice_weight_scale (self , weight_scale : torch .Tensor ) -> torch .Tensor :
102107 assert (
0 commit comments