@@ -183,82 +183,57 @@ def __init__(self,
             eps=1e-4)
 
     def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
-        # self.class_embedding.zero_()
-        # self.class_pos_emb.zero_()
-        # Apply the patch embedding
-        # print_rank_0(self.pre_layernorm.weight)
-        # print_rank_0(self.pre_layernorm.bias)
-
         x = self.patch_embed(x)
 
-        # print_rank_0("In MegatronLM forward, hidden_states:")
-        # print_rank_0(x)
-        # Get batch_size and sequence length
         batch_size = grid_thw.size(0)
         seq_len, hidden_dim = x.size()
-        # Get the rotary position embeddings
+
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
-        # print_rank_0("In MegatronLM forward, rotary_pos_emb")
-        # Create the class token and its position embedding
-        # Create one class token for each batch
+
         class_embedding = self.class_embedding.view(1, -1)
         class_pos_emb = self.class_pos_emb.view(1, -1)
         class_tokens = class_embedding.expand(batch_size, -1)
         class_pos_embs = class_pos_emb.expand(batch_size, -1)
 
-        # Compute the number of tokens each sample occupies in the original sequence
         tokens_per_sample = []
 
         for i in range(batch_size):
             t, h, w = grid_thw[i]
             tokens_per_sample.append((t * h * w).item())
 
-        # Insert the class tokens at the beginning of each batch's token sequence
         new_x = []
         start_idx = 0
         for i in range(batch_size):
-            # Append the current batch's class token
             new_x.append(class_tokens[i:i + 1])
-            # Append the current batch's image tokens
             new_x.append(x[start_idx:start_idx + tokens_per_sample[i]])
             start_idx += tokens_per_sample[i]
 
-        # Concatenate all tokens into one sequence
         x = torch.cat(new_x, dim=0)
 
-
-        # rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
-        # rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-
-        # Likewise prepend the class token's position embedding to the position embeddings
         new_rotary_pos_emb = []
         start_idx = 0
         for i in range(batch_size):
-            # Append the current batch's class token position embedding
             new_rotary_pos_emb.append(class_pos_embs[i:i + 1])
-            # Append the current batch's image token position embeddings
             new_rotary_pos_emb.append(rotary_pos_emb[start_idx:start_idx + tokens_per_sample[i]])
-
             start_idx += tokens_per_sample[i]
-        # Concatenate all position embeddings into one sequence
+
         rotary_pos_emb = torch.cat(new_rotary_pos_emb, dim=0)
-        # Update cu_seqlens; each batch must account for the extra class token
+
         cu_seqlens = []
         cumulative_length = 0
         cu_seqlens.append(cumulative_length)  # starts at 0
         for length in tokens_per_sample:
-            # Each batch's length needs +1 because a class token was added
+
             cumulative_length += int(length + 1)
             cu_seqlens.append(cumulative_length)
 
-        # Convert to a tensor
+
         cu_seqlens = torch.tensor(
             cu_seqlens,
             device=x.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32
         )
-        # print_rank_0(cu_seqlens)
-        # print_rank_0(x.size())
+
         x = x[:, None, :].contiguous()  # [s, h] -> [s, 1, h]
 
         x = self.pre_layernorm(x)
@@ -279,12 +254,8 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         patch_output = []
         start_idx = 0
         for i in range(batch_size):
-            # Skip the class token
             start_idx += 1
-            # Collect the image tokens
             patch_output.append(x[start_idx:start_idx + tokens_per_sample[i]])
-            # print_rank_0(f"start: {start_idx}, end: {start_idx + tokens_per_sample[i]}, tokens_per_sample: {tokens_per_sample[i]}")
-            # Move to the start of the next batch
             start_idx += tokens_per_sample[i]
         patch_output = torch.cat(patch_output, dim=0)  # [original seq_len, hidden_size]
         return patch_output, None
@@ -298,62 +269,53 @@ def forward_debug(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor
 
         batch_size = grid_thw.size(0)
         seq_len, hidden_dim = x.size()
-        # Get the rotary position embeddings
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         class_embedding = self.class_embedding.view(1, -1)
         class_pos_emb = self.class_pos_emb.view(1, -1)
         class_tokens = class_embedding.expand(batch_size, -1)
         class_pos_embs = class_pos_emb.expand(batch_size, -1)
 
-        # Compute the number of tokens each sample occupies in the original sequence
         tokens_per_sample = []
 
         for i in range(batch_size):
             t, h, w = grid_thw[i]
             tokens_per_sample.append((t * h * w).item())
 
-        # Insert the class tokens at the beginning of each batch's token sequence
         new_x = []
         start_idx = 0
         for i in range(batch_size):
-            # Append the current batch's class token
+
             new_x.append(class_tokens[i:i + 1])
-            # Append the current batch's image tokens
+
             new_x.append(x[start_idx:start_idx + tokens_per_sample[i]])
             start_idx += tokens_per_sample[i]
 
         x = torch.cat(new_x, dim=0)
-
-        # Likewise prepend the class token's position embedding to the position embeddings
         new_rotary_pos_emb = []
         start_idx = 0
         for i in range(batch_size):
-            # Append the current batch's class token position embedding
             new_rotary_pos_emb.append(class_pos_embs[i:i + 1])
-            # Append the current batch's image token position embeddings
             new_rotary_pos_emb.append(rotary_pos_emb[start_idx:start_idx + tokens_per_sample[i]])
-
             start_idx += tokens_per_sample[i]
-        # Concatenate all position embeddings into one sequence
+
         rotary_pos_emb = torch.cat(new_rotary_pos_emb, dim=0)
         output["rotary_pos_emb"] = rotary_pos_emb.clone()
         output["class_embedding"] = self.class_embedding.clone()
         cu_seqlens = []
         cumulative_length = 0
         cu_seqlens.append(cumulative_length)  # starts at 0
         for length in tokens_per_sample:
-            # Each batch's length needs +1 because a class token was added
+
             cumulative_length += int(length + 1)
             cu_seqlens.append(cumulative_length)
 
-        # Convert to a tensor
+
         cu_seqlens = torch.tensor(
             cu_seqlens,
             device=x.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32
         )
-        # print_rank_0(cu_seqlens)
-        # print_rank_0(x.size())
+
         x = x[:, None, :].contiguous()  # [s, h] -> [s, 1, h]
 
         x = self.pre_layernorm(x)
@@ -374,12 +336,8 @@ def forward_debug(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor
         patch_output = []
         start_idx = 0
         for i in range(batch_size):
-            # Skip the class token
             start_idx += 1
-            # Collect the image tokens
             patch_output.append(x[start_idx:start_idx + tokens_per_sample[i]])
-            # print_rank_0(f"start: {start_idx}, end: {start_idx + tokens_per_sample[i]}, tokens_per_sample: {tokens_per_sample[i]}")
-            # Move to the start of the next batch
             start_idx += tokens_per_sample[i]
         patch_output = torch.cat(patch_output, dim=0)  # [original seq_len, hidden_size]
         output["before_adapter"] = patch_output.clone()
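
For reference, a minimal standalone sketch of the packing scheme the retained code implements: one class token is prepended per sample in the packed token sequence, `cu_seqlens` grows by one entry length per sample to account for it, and the class token is stripped again after the transformer blocks. All names and sizes here (`hidden`, the `grid_thw` values, the zero stand-in for `class_embedding`) are illustrative, not taken from this module.

```python
import torch

# Hypothetical sizes for illustration only.
hidden = 4
grid_thw = torch.tensor([[1, 2, 2], [1, 2, 3]])      # (t, h, w) per sample
tokens_per_sample = [int(t * h * w) for t, h, w in grid_thw]
x = torch.randn(sum(tokens_per_sample), hidden)      # packed image tokens
class_embedding = torch.zeros(hidden)                # stand-in for the learned class token

# Prepend one class token per sample in the packed sequence.
new_x, start = [], 0
for n in tokens_per_sample:
    new_x.append(class_embedding.view(1, -1))        # class token first
    new_x.append(x[start:start + n])                 # then the sample's image tokens
    start += n
x = torch.cat(new_x, dim=0)                          # [total_tokens + batch, hidden]

# cu_seqlens as consumed by varlen attention: each segment is n + 1 tokens long.
cu_seqlens = [0]
for n in tokens_per_sample:
    cu_seqlens.append(cu_seqlens[-1] + n + 1)
print(cu_seqlens)                                    # [0, 5, 12] for the grids above

# After the transformer blocks, the class token is dropped again.
patch_out, start = [], 0
for n in tokens_per_sample:
    start += 1                                       # skip the class token
    patch_out.append(x[start:start + n])
    start += n
patch_out = torch.cat(patch_out, dim=0)              # back to [total_tokens, hidden]
assert patch_out.shape[0] == sum(tokens_per_sample)
```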