@@ -97,7 +97,7 @@ def test_llava_export(self):
97
97
)[0 ]
98
98
llava_module .run_method (
99
99
"text_decoder" ,
100
- (torch .tensor ([start_pos ], dtype = torch .int64 ), pte_embeds_before_img ),
100
+ (pte_embeds_before_img , torch .tensor ([start_pos ], dtype = torch .int64 )),
101
101
)
102
102
103
103
# Update the start_pos. start_pos is used in kv cache. The source of truth
@@ -109,8 +109,8 @@ def test_llava_export(self):
109
109
llava_module .run_method (
110
110
"text_decoder" ,
111
111
(
112
- torch .tensor ([start_pos ], dtype = torch .int64 ),
113
112
pte_embeds_img ,
113
+ torch .tensor ([start_pos ], dtype = torch .int64 ),
114
114
),
115
115
)
116
116
@@ -123,7 +123,7 @@ def test_llava_export(self):
123
123
)[0 ]
124
124
pte_prefill_after_img = llava_module .run_method (
125
125
"text_decoder" ,
126
- (torch .tensor ([start_pos ], dtype = torch .int64 ), pte_embeds_after_img ),
126
+ (pte_embeds_after_img , torch .tensor ([start_pos ], dtype = torch .int64 )),
127
127
)[0 ]
128
128
129
129
# Update the logits for each prefill (kv cache) step.
@@ -140,7 +140,7 @@ def test_llava_export(self):
140
140
)[0 ]
141
141
logits = llava_module .run_method (
142
142
"text_decoder" ,
143
- (torch .tensor ([start_pos + i ], dtype = torch .int64 ), token_embeds ),
143
+ (token_embeds , torch .tensor ([start_pos + i ], dtype = torch .int64 )),
144
144
)[0 ]
145
145
new_tokens .append (torch .argmax (logits ).item ())
146
146
0 commit comments