Skip to content

Commit 23acfea

Browse files
authored
Swap Llava export arg order (#14238)
Swaps the Llava export argument order so that `forward` takes `(embeddings, cache_position)` instead of `(cache_position, embeddings)`. (Note: in the exported model and runners the position argument is named `input_pos`/`start_pos`.)
1 parent a77c8df commit 23acfea

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

examples/models/llava/export_llava.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(self, llava):
7777
super().__init__()
7878
self.text_model = llava.text_model
7979

80-
def forward(self, input_pos, embeddings):
80+
def forward(self, embeddings, input_pos):
8181
return self.text_model(None, {"input_pos": input_pos}, embeddings)
8282

8383
llava_text_model = LlavaTextModel(llava)
@@ -88,7 +88,7 @@ def forward(self, input_pos, embeddings):
8888
max_seq_len=llava.text_model_args.max_seq_len,
8989
dtype=DType.fp32,
9090
use_kv_cache=True,
91-
example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
91+
example_inputs=(embeddings, torch.tensor([0], dtype=torch.int64)),
9292
dynamic_shapes=dynamic_shapes,
9393
)
9494

examples/models/llava/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,5 +405,5 @@ def _get_image_dynamic_shapes(self):
405405

406406
def _get_prompt_dynamic_shapes(self):
407407
dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
408-
text_model_dynamic_shapes = ({0: 1}, {1: dim})
408+
text_model_dynamic_shapes = ({1: dim}, {0: 1})
409409
return text_model_dynamic_shapes

examples/models/llava/runner/llava_image_prefiller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller {
4747

4848
// Run text model
4949
auto outputs_res = ET_UNWRAP(module_->execute(
50-
kTextModelMethod, {start_pos_tensor, image_encoder_outputs[0]}));
50+
kTextModelMethod, {image_encoder_outputs[0], start_pos_tensor}));
5151
ET_CHECK_MSG(
5252
outputs_res[0].isTensor(),
5353
"Non Tensor Output returned from executing image prefill");

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
3434
&start_pos, {1}, executorch::aten::ScalarType::Long);
3535
// run text model
3636
auto outputs_res = ET_UNWRAP(module_->execute(
37-
kTextModelMethod, {start_pos_tensor, token_embedding_outputs[0]}));
37+
kTextModelMethod, {token_embedding_outputs[0], start_pos_tensor}));
3838

3939
ET_CHECK_MSG(
4040
outputs_res.size() == 1,

examples/models/llava/test/test_llava.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_llava_export(self):
9797
)[0]
9898
llava_module.run_method(
9999
"text_decoder",
100-
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
100+
(pte_embeds_before_img, torch.tensor([start_pos], dtype=torch.int64)),
101101
)
102102

103103
# Update the start_pos. start_pos is used in kv cache. The source of truth
@@ -109,8 +109,8 @@ def test_llava_export(self):
109109
llava_module.run_method(
110110
"text_decoder",
111111
(
112-
torch.tensor([start_pos], dtype=torch.int64),
113112
pte_embeds_img,
113+
torch.tensor([start_pos], dtype=torch.int64),
114114
),
115115
)
116116

@@ -123,7 +123,7 @@ def test_llava_export(self):
123123
)[0]
124124
pte_prefill_after_img = llava_module.run_method(
125125
"text_decoder",
126-
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
126+
(pte_embeds_after_img, torch.tensor([start_pos], dtype=torch.int64)),
127127
)[0]
128128

129129
# Update the logits for each prefill (kv cache) step.
@@ -140,7 +140,7 @@ def test_llava_export(self):
140140
)[0]
141141
logits = llava_module.run_method(
142142
"text_decoder",
143-
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
143+
(token_embeds, torch.tensor([start_pos + i], dtype=torch.int64)),
144144
)[0]
145145
new_tokens.append(torch.argmax(logits).item())
146146

0 commit comments

Comments (0)