44
55
class MMWeightTpl(BaseWeightTpl):
    """Base container for a single matmul weight with an optional bias."""

    def __init__(self, data_type):
        super().__init__()
        self.data_type_ = data_type  # dtype that loaded tensors are cast to
        self.quant_method = None     # optional quantization strategy, set later
        self.weight = None           # filled in by the subclass loaders
        self.bias = None
@@ -40,7 +38,9 @@ def _post_load_weights(self):
4038
class MMWeight(MMWeightTpl):
    """A single named weight, row-split across tensor-parallel ranks."""

    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
        super().__init__(data_type)
        # This rank owns rows [start, end) of the full tensor.
        self.start = split_n_embed * self.tp_rank_
        self.end = split_n_embed * (self.tp_rank_ + 1)
        self.weight_name = weight_name
        self.bias_name = bias_name
4646
@@ -72,7 +72,7 @@ def load_hf_weights(self, weights):
7272 return
7373
7474
75- class ROWMMWeightNoTP (MMWeight ):
75+ class ROWMMWeightNoTP (ROWMMWeight ):
7676 def __init__ (self , weight_name , data_type , split_n_embed , bias_name = None ):
7777 super ().__init__ (weight_name , data_type , split_n_embed , bias_name )
7878 self .start = 0
@@ -98,13 +98,20 @@ def load_hf_weights(self, weights):
9898
9999
class MultiMMWeight(MMWeightTpl):
    """Container for several weights that are loaded separately and then fused.

    Each sub-weight gets its own tensor-parallel slice; ``split_n_embeds`` may
    be a single int (uniform split) or one split size per weight name.
    """

    def __init__(self, weight_names, data_type, split_n_embeds, bias_names=None):
        super().__init__(data_type)
        # FIX: the default used to be a mutable `[]`, which is shared across
        # every instance created without bias_names. Use None as the sentinel
        # and materialize a fresh list per call.
        bias_names = [] if bias_names is None else bias_names
        # Accept either a uniform int split or an explicit per-weight list.
        if isinstance(split_n_embeds, int):
            self.split_n_embeds = [split_n_embeds] * len(weight_names)
        else:
            self.split_n_embeds = split_n_embeds

        # Per-weight [start, end) slice owned by this tensor-parallel rank.
        self.starts = [n * self.tp_rank_ for n in self.split_n_embeds]
        self.ends = [n * (self.tp_rank_ + 1) for n in self.split_n_embeds]
        self.weight_names = weight_names
        self.bias_names = bias_names
        self.weights = [None] * len(self.weight_names)
        self.biases = [None] * len(self.bias_names)
        # Bias handling is enabled only when at least one name was given and
        # none of them is None (short-circuits before the all() scan).
        self.has_bias = len(bias_names) > 0 and all(b is not None for b in bias_names)
108115
109116 def verify_load (self ):
110117 load_ok = True
@@ -117,7 +124,7 @@ def verify_load(self):
117124
118125
class MultiROWMMWeight(MultiMMWeight):
    """Row-sliced (dim 0) multi-weight; slicing bookkeeping is inherited."""

    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
        # FIX: replace the mutable `[]` default with None. Normalize here,
        # before delegating, so this fix is correct even if the base class
        # still declares a list default.
        super().__init__(
            weight_names,
            data_type,
            split_n_embed,
            [] if bias_names is None else bias_names,
        )
122129
123130 def _fuse (self ):
@@ -134,86 +141,48 @@ def load_hf_weights(self, weights):
134141 for i in range (len (self .weight_names )):
135142 if self .weight_names [i ] in weights :
136143 weight = weights [self .weight_names [i ]].to (self .data_type_ )
137- self .weights [i ] = weight [self .start : self .end ]
144+ self .weights [i ] = weight [self .starts [ i ] : self .ends [ i ] ]
138145 if self .has_bias and self .bias_names [i ] in weights :
139146 bias = weights [self .bias_names [i ]].to (self .data_type_ )
140- self .biases [i ] = bias [self .start : self .end ]
147+ self .biases [i ] = bias [self .starts [ i ] : self .ends [ i ] ]
141148 self ._fuse ()
142149 return
143150
144151
class MultiROWMMWeightNoTP(MultiROWMMWeight):
    """MultiROWMMWeight without tensor parallelism: every rank keeps the
    full rows, so each slice is simply [0, split_n_embed)."""

    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
        # FIX: None sentinel instead of a shared mutable `[]` default.
        super().__init__(
            weight_names,
            data_type,
            split_n_embed,
            [] if bias_names is None else bias_names,
        )
        # FIX: idiomatic forms — `[0 for i in xs]` and `[i for i in xs]` are
        # just `[0] * len(xs)` and `list(xs)`.
        self.starts = [0] * len(self.split_n_embeds)
        self.ends = list(self.split_n_embeds)
150157
151158
152- class CustomMMWeight (ROWMMWeight ):
153- def __init__ (
154- self ,
155- weight_name ,
156- data_type ,
157- split_n_embed ,
158- bias_name = None ,
159- wait_fuse = False ,
160- disable_tp = False ,
161- custom_load = None ,
162- custom_fuse = None ,
163- ):
164- super ().__init__ (weight_name , data_type , split_n_embed , bias_name , wait_fuse = wait_fuse , disable_tp = disable_tp )
165- self .custom_load = custom_load
166- self .custom_fuse = custom_fuse
167-
168- def fuse (self , B , op = None ):
169- if self .custom_fuse is None :
170- super ().fuse (B , op )
171- else :
172- weight = self .custom_fuse (self , B )
173- self .post_load_weights (weight )
class MultiCOLMMWeight(MultiROWMMWeight):
    """Column-sliced (dim 1) multi-weight: this rank owns columns
    [start, end) of each named weight."""

    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
        # FIX: None sentinel instead of a shared mutable `[]` default;
        # normalize before delegating so the fix stands on its own.
        super().__init__(
            weight_names,
            data_type,
            split_n_embed,
            [] if bias_names is None else bias_names,
        )

    def load_hf_weights(self, weights):
        """Load any present weights/biases, slicing columns for this rank,
        then attempt fusion."""
        # FIX: dropped the dead `weight = None` initializer — the variable was
        # only ever assigned and used inside the loop body.
        for i in range(len(self.weight_names)):
            if self.weight_names[i] in weights:
                weight = weights[self.weight_names[i]].to(self.data_type_)
                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
            if self.has_bias and self.bias_names[i] in weights:
                bias = weights[self.bias_names[i]].to(self.data_type_)
                # NOTE(review): slicing the bias along dim 1 assumes a 2-D bias
                # tensor; the ROW variant slices dim 0. Confirm against the
                # checkpoint layout.
                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
        self._fuse()
        return
189174
190175
191- class CustomBMMWeight (CustomMMWeight ):
192- def __init__ (
193- self ,
194- weight_name ,
195- data_type ,
196- split_n_embed ,
197- bias_name = None ,
198- wait_fuse = False ,
199- disable_tp = False ,
200- custom_load = None ,
201- custom_fuse = None ,
202- ):
203- super ().__init__ (
204- weight_name ,
205- data_type ,
206- split_n_embed ,
207- bias_name ,
208- wait_fuse = wait_fuse ,
209- disable_tp = disable_tp ,
210- custom_load = custom_load ,
211- custom_fuse = custom_fuse ,
212- )
class BMMWeightTpl(BaseWeightTpl):
    """Base container for batched-matmul (bmm) weights."""

    def __init__(self, data_type):
        super().__init__()
        self.data_type_ = data_type  # dtype loaded tensors are cast to
        self.quant_method = None     # stays None — see set_quant_method
        self.weight = None
        self.bias = None
213183
214184 def set_quant_method (self , quant_method ):
215- return
216- raise NotImplementedError ("BMM does not currently support quantification" )
185+ self .quant_method = None
217186
218187 def bmm (self , input_tensor , out = None , use_custom_tensor_mananger = True ):
219188 if self .quant_method is not None :
@@ -230,8 +199,52 @@ def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
230199 return torch .bmm (input_tensor , self .weight , out = out )
231200 return torch .addbmm (self .bias , input_tensor , self .weight , out = out )
232201
233- def post_load_weights (self , weight ):
234- if self .quant_method is not None :
235- self .weight = self .quant_method .quantize (weight .cuda (self .tp_rank_ ))
236- return
237- self .weight = weight .cuda (self .tp_rank_ )
202+ def _post_load_weights (self ):
203+ self .weight = self .weight .cuda (self .tp_rank_ )
204+
205+
class BMMWeight(BMMWeightTpl):
    """A single bmm weight with optional bias, row-split across TP ranks."""

    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
        super().__init__(data_type)
        # Rows [start, end) owned by this tensor-parallel rank.
        self.start = split_n_embed * self.tp_rank_
        self.end = split_n_embed * (self.tp_rank_ + 1)
        self.weight_name = weight_name
        self.bias_name = bias_name

    def verify_load(self):
        """Return True once the weight — and the bias, when one is named —
        have actually been loaded."""
        if self.weight is None:
            return False
        # A declared bias must also have been materialized.
        if self.bias_name is not None and self.bias is None:
            return False
        return True
222+
223+
class ROWBMMWeight(BMMWeight):
    """Row-parallel bmm weight."""

    # Reuse the row-slicing loader from the MM counterpart wholesale.
    load_hf_weights = ROWMMWeight.load_hf_weights

    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
        super().__init__(weight_name, data_type, split_n_embed, bias_name)
236+
class COLBMMWeight(BMMWeight):
    """Column-parallel bmm weight: loads via COLMMWeight's loader and stores
    the slice transposed for bmm."""

    # Reuse the column-slicing loader from the MM counterpart wholesale.
    load_hf_weights = COLMMWeight.load_hf_weights

    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
        super().__init__(weight_name, data_type, split_n_embed, bias_name)

    def _post_load_weights(self):
        # Swap the first two dims before moving to this rank's GPU so the
        # stored layout matches what torch.bmm expects.
        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
0 commit comments