@@ -97,7 +97,7 @@ def test_llava_export(self):
9797 )[0 ]
9898 llava_module .run_method (
9999 "text_decoder" ,
100- (torch .tensor ([start_pos ], dtype = torch .int64 ), pte_embeds_before_img ),
100+ (pte_embeds_before_img , torch .tensor ([start_pos ], dtype = torch .int64 )),
101101 )
102102
103103 # Update the start_pos. start_pos is used in kv cache. The source of truth
@@ -109,8 +109,8 @@ def test_llava_export(self):
109109 llava_module .run_method (
110110 "text_decoder" ,
111111 (
112- torch .tensor ([start_pos ], dtype = torch .int64 ),
113112 pte_embeds_img ,
113+ torch .tensor ([start_pos ], dtype = torch .int64 ),
114114 ),
115115 )
116116
@@ -123,7 +123,7 @@ def test_llava_export(self):
123123 )[0 ]
124124 pte_prefill_after_img = llava_module .run_method (
125125 "text_decoder" ,
126- (torch .tensor ([start_pos ], dtype = torch .int64 ), pte_embeds_after_img ),
126+ (pte_embeds_after_img , torch .tensor ([start_pos ], dtype = torch .int64 )),
127127 )[0 ]
128128
129129 # Update the logits for each prefill (kv cache) step.
@@ -140,7 +140,7 @@ def test_llava_export(self):
140140 )[0 ]
141141 logits = llava_module .run_method (
142142 "text_decoder" ,
143- (torch .tensor ([start_pos + i ], dtype = torch .int64 ), token_embeds ),
143+ (token_embeds , torch .tensor ([start_pos + i ], dtype = torch .int64 )),
144144 )[0 ]
145145 new_tokens .append (torch .argmax (logits ).item ())
146146
0 commit comments