
Commit 62603c9

Fix qwen2vl_mrope unit test (#728)
## Summary

Adapt to the transformers VLM config change: the Qwen2-VL text-model parameters (e.g. `head_dim`) now live on `Qwen2VLTextConfig`, so the unit test constructs the rotary embedding from that class instead of `Qwen2VLConfig`.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <[email protected]>
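For context, a minimal sketch (not part of this commit) of how a test could import the right config class across transformers versions; the alias name `TextConfig` is a hypothetical choice for illustration, and the commit itself simply switches to the new name:

```python
# Compatibility sketch (assumption, not from this commit): newer transformers
# releases split out Qwen2VLTextConfig, while older ones only expose
# Qwen2VLConfig.
try:
    from transformers.models.qwen2_vl.configuration_qwen2_vl import (
        Qwen2VLTextConfig as TextConfig,  # new, split-out text config
    )
except ImportError:
    from transformers.models.qwen2_vl.configuration_qwen2_vl import (
        Qwen2VLConfig as TextConfig,  # pre-split fallback
    )
```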
1 parent 5d25e46 commit 62603c9

File tree: 1 file changed (+3, -3 lines)

test/transformers/test_qwen2vl_mrope.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -4,7 +4,7 @@
 from test.utils import supports_bfloat16
 
 try:
-    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
+    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
     from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
@@ -45,7 +45,7 @@
     ],
 )
 def test_correctness(bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol):
-    rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLConfig(head_dim=head_dim), device=device)
+    rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLTextConfig(head_dim=head_dim), device=device)
 
     _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype)
@@ -105,7 +105,7 @@ def test_functional_correctness(bsz, seq_len, num_q_heads, num_kv_heads, head_di
     k1 = _k.clone().requires_grad_(True)
     k2 = _k.clone().requires_grad_(True)
 
-    rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLConfig(head_dim=head_dim), device=device)
+    rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLTextConfig(head_dim=head_dim), device=device)
 
     pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(3, bsz, seq_len)
     cos, sin = rotary_emb(k1, pos_ids)
```
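Putting the change together, a self-contained sketch of what the updated tests now construct. Class names, the `pos_ids` shape, and the call pattern are taken from the diff above; the concrete sizes are illustrative assumptions:

```python
import torch
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding

device = "cuda" if torch.cuda.is_available() else "cpu"
bsz, seq_len, num_kv_heads, head_dim = 1, 128, 4, 64  # illustrative sizes

# The rotary embedding is now built from the split-out text config.
rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLTextConfig(head_dim=head_dim), device=device)

# mrope uses three position-id streams, hence the leading dimension of 3,
# mirroring the test's pos_ids construction.
pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(3, bsz, seq_len)

k = torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device).transpose(1, 2)
cos, sin = rotary_emb(k, pos_ids)
```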
