diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index f903e0f2ecf..7e571087c1d 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -77,7 +77,7 @@ def __init__(self, llava): super().__init__() self.text_model = llava.text_model - def forward(self, input_pos, embeddings): + def forward(self, embeddings, input_pos): return self.text_model(None, {"input_pos": input_pos}, embeddings) llava_text_model = LlavaTextModel(llava) @@ -88,7 +88,7 @@ def forward(self, input_pos, embeddings): max_seq_len=llava.text_model_args.max_seq_len, dtype=DType.fp32, use_kv_cache=True, - example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings), + example_inputs=(embeddings, torch.tensor([0], dtype=torch.int64)), dynamic_shapes=dynamic_shapes, ) diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py index 3973d756e9c..9ff56124174 100644 --- a/examples/models/llava/model.py +++ b/examples/models/llava/model.py @@ -405,5 +405,5 @@ def _get_image_dynamic_shapes(self): def _get_prompt_dynamic_shapes(self): dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len) - text_model_dynamic_shapes = ({0: 1}, {1: dim}) + text_model_dynamic_shapes = ({1: dim}, {0: 1}) return text_model_dynamic_shapes diff --git a/examples/models/llava/runner/llava_image_prefiller.h b/examples/models/llava/runner/llava_image_prefiller.h index 9edfab85904..f5f316d0cac 100644 --- a/examples/models/llava/runner/llava_image_prefiller.h +++ b/examples/models/llava/runner/llava_image_prefiller.h @@ -47,7 +47,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller { // Run text model auto outputs_res = ET_UNWRAP(module_->execute( - kTextModelMethod, {start_pos_tensor, image_encoder_outputs[0]})); + kTextModelMethod, {image_encoder_outputs[0], start_pos_tensor})); ET_CHECK_MSG( outputs_res[0].isTensor(), "Non Tensor Output returned from executing image prefill"); diff --git a/examples/models/llava/runner/llava_text_decoder_runner.h b/examples/models/llava/runner/llava_text_decoder_runner.h index cfa92e0c253..691e2f4aa1e 100644 --- a/examples/models/llava/runner/llava_text_decoder_runner.h +++ b/examples/models/llava/runner/llava_text_decoder_runner.h @@ -34,7 +34,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner &start_pos, {1}, executorch::aten::ScalarType::Long); // run text model auto outputs_res = ET_UNWRAP(module_->execute( - kTextModelMethod, {start_pos_tensor, token_embedding_outputs[0]})); + kTextModelMethod, {token_embedding_outputs[0], start_pos_tensor})); ET_CHECK_MSG( outputs_res.size() == 1, diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py index def9eaa02bd..7f2b59e0116 100644 --- a/examples/models/llava/test/test_llava.py +++ b/examples/models/llava/test/test_llava.py @@ -97,7 +97,7 @@ def test_llava_export(self): )[0] llava_module.run_method( "text_decoder", - (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img), + (pte_embeds_before_img, torch.tensor([start_pos], dtype=torch.int64)), ) # Update the start_pos. start_pos is used in kv cache. The source of truth @@ -109,8 +109,8 @@ def test_llava_export(self): llava_module.run_method( "text_decoder", ( - torch.tensor([start_pos], dtype=torch.int64), pte_embeds_img, + torch.tensor([start_pos], dtype=torch.int64), ), ) @@ -123,7 +123,7 @@ def test_llava_export(self): )[0] pte_prefill_after_img = llava_module.run_method( "text_decoder", - (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img), + (pte_embeds_after_img, torch.tensor([start_pos], dtype=torch.int64)), )[0] # Update the logits for each prefill (kv cache) step. @@ -140,7 +140,7 @@ def test_llava_export(self): )[0] logits = llava_module.run_method( "text_decoder", - (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds), + (token_embeds, torch.tensor([start_pos + i], dtype=torch.int64)), )[0] new_tokens.append(torch.argmax(logits).item())