Skip to content

Commit 5c2d7a0

Browse files
committed
Fix tests
1 parent 59cfdb7 commit 5c2d7a0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/models/llava/test/test_llava.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_llava_export(self):
9797
)[0]
9898
llava_module.run_method(
9999
"text_decoder",
100-
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
100+
(pte_embeds_before_img, torch.tensor([start_pos], dtype=torch.int64)),
101101
)
102102

103103
# Update the start_pos. start_pos is used in kv cache. The source of truth
@@ -109,8 +109,8 @@ def test_llava_export(self):
109109
llava_module.run_method(
110110
"text_decoder",
111111
(
112-
torch.tensor([start_pos], dtype=torch.int64),
113112
pte_embeds_img,
113+
torch.tensor([start_pos], dtype=torch.int64),
114114
),
115115
)
116116

@@ -123,7 +123,7 @@ def test_llava_export(self):
123123
)[0]
124124
pte_prefill_after_img = llava_module.run_method(
125125
"text_decoder",
126-
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
126+
(pte_embeds_after_img, torch.tensor([start_pos], dtype=torch.int64)),
127127
)[0]
128128

129129
# Update the logits for each prefill (kv cache) step.
@@ -140,7 +140,7 @@ def test_llava_export(self):
140140
)[0]
141141
logits = llava_module.run_method(
142142
"text_decoder",
143-
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
143+
(token_embeds, torch.tensor([start_pos + i], dtype=torch.int64)),
144144
)[0]
145145
new_tokens.append(torch.argmax(logits).item())
146146

0 commit comments

Comments
 (0)