
Commit fff1ea6

committed
updated
1 parent a8163bc commit fff1ea6

File tree

1 file changed: 13 additions, 55 deletions


aiak_training_llm/models/llavaov_1_5/rice_vision_model.py

Lines changed: 13 additions & 55 deletions
@@ -183,82 +183,57 @@ def __init__(self,
             eps=1e-4)
 
     def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
-        # self.class_embedding.zero_()
-        # self.class_pos_emb.zero_()
-        # apply the patch embedding
-        # print_rank_0(self.pre_layernorm.weight)
-        # print_rank_0(self.pre_layernorm.bias)
-
         x = self.patch_embed(x)
 
-        # print_rank_0("In MegatronLM forward, hidden_states:")
-        # print_rank_0(x)
-        # get batch_size and the sequence length
         batch_size = grid_thw.size(0)
         seq_len, hidden_dim = x.size()
-        # get the rotary position embeddings
+
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
-        # print_rank_0("In MegatronLM forward, rotary_pos_emb")
-        # create the class token and its position embedding
-        # one class token per batch sample
+
         class_embedding = self.class_embedding.view(1, -1)
         class_pos_emb = self.class_pos_emb.view(1, -1)
         class_tokens = class_embedding.expand(batch_size, -1)
         class_pos_embs = class_pos_emb.expand(batch_size, -1)
 
-        # count how many tokens each sample contributes to the original sequence
         tokens_per_sample = []
 
         for i in range(batch_size):
             t, h, w = grid_thw[i]
             tokens_per_sample.append((t * h * w).item())
 
-        # insert a class token at the start of each sample's token sequence
         new_x = []
         start_idx = 0
         for i in range(batch_size):
-            # append the current sample's class token
             new_x.append(class_tokens[i:i+1])
-            # append the current sample's image tokens
             new_x.append(x[start_idx:start_idx+tokens_per_sample[i]])
             start_idx += tokens_per_sample[i]
 
-        # concatenate all tokens into a single sequence
         x = torch.cat(new_x, dim=0)
 
-
-        # rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
-        # rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-
-        # likewise, prepend each sample's class-token position embedding
         new_rotary_pos_emb = []
         start_idx = 0
         for i in range(batch_size):
-            # append the current sample's class-token position embedding
             new_rotary_pos_emb.append(class_pos_embs[i:i+1])
-            # append the current sample's image-token position embeddings
             new_rotary_pos_emb.append(rotary_pos_emb[start_idx:start_idx+tokens_per_sample[i]])
-
             start_idx += tokens_per_sample[i]
-        # concatenate all position embeddings into a single sequence
+
         rotary_pos_emb = torch.cat(new_rotary_pos_emb, dim=0)
-        # update cu_seqlens; each sample needs one extra slot for its class token
+
         cu_seqlens = []
         cumulative_length = 0
         cu_seqlens.append(cumulative_length)  # starts at 0
         for length in tokens_per_sample:
-            # each sample's length gets +1 because a class token was added
+
            cumulative_length += int(length + 1)
            cu_seqlens.append(cumulative_length)
 
-        # convert to a tensor
+
         cu_seqlens = torch.tensor(
             cu_seqlens,
             device=x.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32
         )
-        # print_rank_0(cu_seqlens)
-        # print_rank_0(x.size())
+
         x = x[:, None, :].contiguous()  # [s, h] -> [s, 1, h]
 
         x = self.pre_layernorm(x)
@@ -279,12 +254,8 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         patch_output = []
         start_idx = 0
         for i in range(batch_size):
-            # skip the class token
             start_idx += 1
-            # collect the image tokens
             patch_output.append(x[start_idx:start_idx+tokens_per_sample[i]])
-            # print_rank_0(f"start: {start_idx}, end: {start_idx + tokens_per_sample[i]}, tokens_per_sample: {tokens_per_sample[i]}")
-            # move to the start of the next sample
             start_idx += tokens_per_sample[i]
         patch_output = torch.cat(patch_output, dim=0)  # [original seq_len, hidden_size]
         return patch_output, None
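
As a reading aid (not part of the commit), the cu_seqlens bookkeeping that forward keeps can be reproduced in isolation: each sample contributes t * h * w image tokens plus one prepended class token, and cu_seqlens records the cumulative sequence boundaries. The grid_thw values below are made up.

import torch

grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])  # hypothetical (t, h, w) per sample

# tokens contributed by each sample's image patches
tokens_per_sample = [int(t * h * w) for t, h, w in grid_thw.tolist()]  # [16, 4]

# cumulative boundaries, with one extra slot per sample for its class token
cu_seqlens = [0]
for length in tokens_per_sample:
    cu_seqlens.append(cu_seqlens[-1] + length + 1)

cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0, 17, 22], dtype=torch.int32)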
@@ -298,62 +269,53 @@ def forward_debug(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor
 
         batch_size = grid_thw.size(0)
         seq_len, hidden_dim = x.size()
-        # get the rotary position embeddings
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         class_embedding = self.class_embedding.view(1, -1)
         class_pos_emb = self.class_pos_emb.view(1, -1)
         class_tokens = class_embedding.expand(batch_size, -1)
         class_pos_embs = class_pos_emb.expand(batch_size, -1)
 
-        # count how many tokens each sample contributes to the original sequence
         tokens_per_sample = []
 
         for i in range(batch_size):
             t, h, w = grid_thw[i]
             tokens_per_sample.append((t * h * w).item())
 
-        # insert a class token at the start of each sample's token sequence
         new_x = []
         start_idx = 0
         for i in range(batch_size):
-            # append the current sample's class token
+
             new_x.append(class_tokens[i:i+1])
-            # append the current sample's image tokens
+
             new_x.append(x[start_idx:start_idx+tokens_per_sample[i]])
             start_idx += tokens_per_sample[i]
 
         x = torch.cat(new_x, dim=0)
-
-        # likewise, prepend each sample's class-token position embedding
         new_rotary_pos_emb = []
         start_idx = 0
         for i in range(batch_size):
-            # append the current sample's class-token position embedding
             new_rotary_pos_emb.append(class_pos_embs[i:i+1])
-            # append the current sample's image-token position embeddings
             new_rotary_pos_emb.append(rotary_pos_emb[start_idx:start_idx+tokens_per_sample[i]])
-
             start_idx += tokens_per_sample[i]
-        # concatenate all position embeddings into a single sequence
+
         rotary_pos_emb = torch.cat(new_rotary_pos_emb, dim=0)
         output["rotary_pos_emb"] = rotary_pos_emb.clone()
         output["class_embedding"] = self.class_embedding.clone()
         cu_seqlens = []
         cumulative_length = 0
         cu_seqlens.append(cumulative_length)  # starts at 0
         for length in tokens_per_sample:
-            # each sample's length gets +1 because a class token was added
+
             cumulative_length += int(length + 1)
             cu_seqlens.append(cumulative_length)
 
-        # convert to a tensor
+
         cu_seqlens = torch.tensor(
             cu_seqlens,
             device=x.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32
         )
-        # print_rank_0(cu_seqlens)
-        # print_rank_0(x.size())
+
         x = x[:, None, :].contiguous()  # [s, h] -> [s, 1, h]
 
         x = self.pre_layernorm(x)
@@ -374,12 +336,8 @@ def forward_debug(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor
         patch_output = []
         start_idx = 0
         for i in range(batch_size):
-            # skip the class token
             start_idx += 1
-            # collect the image tokens
             patch_output.append(x[start_idx:start_idx+tokens_per_sample[i]])
-            # print_rank_0(f"start: {start_idx}, end: {start_idx + tokens_per_sample[i]}, tokens_per_sample: {tokens_per_sample[i]}")
-            # move to the start of the next sample
             start_idx += tokens_per_sample[i]
         patch_output = torch.cat(patch_output, dim=0)  # [original seq_len, hidden_size]
         output["before_adapter"] = patch_output.clone()
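
Both methods wrap the transformer blocks in the same prepend/strip round trip: a class token is inserted before each sample's patch tokens on the way in and dropped again on the way out, so the returned patch_output keeps the original sequence length. The snippet below is a minimal sketch of that round trip, not part of the commit, with made-up sizes and a zero tensor standing in for the learned class embedding.

import torch

hidden_dim = 8
tokens_per_sample = [16, 4]                      # per-sample patch counts (hypothetical)
x = torch.randn(sum(tokens_per_sample), hidden_dim)
class_token = torch.zeros(1, hidden_dim)         # stand-in for self.class_embedding

# prepend one class token per sample
new_x, start = [], 0
for n in tokens_per_sample:
    new_x.append(class_token)
    new_x.append(x[start:start + n])
    start += n
x_with_cls = torch.cat(new_x, dim=0)             # 22 tokens = 20 patches + 2 class tokens

# strip the class tokens back out, as the tail of forward/forward_debug does
patch_output, start = [], 0
for n in tokens_per_sample:
    start += 1                                   # skip the class token
    patch_output.append(x_with_cls[start:start + n])
    start += n
patch_output = torch.cat(patch_output, dim=0)

assert patch_output.shape == x.shape
assert torch.equal(patch_output, x)              # sequence length and content preserved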
