
Commit 565245f

Update qwen2_5_vl.py
1 parent 349f7e5 commit 565245f

1 file changed: +25 −0 lines changed

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 25 additions & 0 deletions
@@ -680,6 +680,31 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
     def rotary_pos_emb_thw(self, t, h, w):
         hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
         wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
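For readers following the diff, the sketch below replays the logic of the added rot_pos_emb outside the model class on a toy 1×4×4 grid, so the merge-block ordering of the (h, w) position ids can be inspected directly. The spatial_merge_size of 2, the 8-wide rotary dimension, and the inverse-frequency table standing in for self.rotary_pos_emb are illustrative assumptions, not values taken from the commit; only the indexing mirrors the new method.

import torch

# Minimal, self-contained sketch. Assumptions: spatial_merge_size=2, dim=8,
# and a plain inverse-frequency table standing in for self.rotary_pos_emb.
spatial_merge_size = 2
dim = 8  # rotary embedding width, chosen for illustration only

grid_thw = torch.tensor([[1, 4, 4]])  # one image: t=1, h=4, w=4 patches

pos_ids = []
for t, h, w in grid_thw:
    hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
    wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
    # Reorder so that patches merged together later (spatial_merge_size x
    # spatial_merge_size blocks) become contiguous in the flattened sequence.
    hpos_ids = hpos_ids.reshape(
        h // spatial_merge_size, spatial_merge_size,
        w // spatial_merge_size, spatial_merge_size,
    ).permute(0, 2, 1, 3).flatten()
    wpos_ids = wpos_ids.reshape(
        h // spatial_merge_size, spatial_merge_size,
        w // spatial_merge_size, spatial_merge_size,
    ).permute(0, 2, 1, 3).flatten()
    pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)  # (num_patches, 2): per-patch (h, w) index

# Hypothetical frequency table indexed by grid coordinate, one row per position.
max_grid_size = grid_thw[:, 1:].max()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(max_grid_size).float(), inv_freq)

# Gather the h-axis and w-axis angles and concatenate them, as in the diff.
rotary_pos_emb = freqs[pos_ids].flatten(1)
print(pos_ids[:8])           # block-ordered (h, w) indices
print(rotary_pos_emb.shape)  # torch.Size([16, 8])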
