Skip to content

Commit 2dd30f5

Browse files
committed
Swap Llava export arg order
1 parent 6d8583d commit 2dd30f5

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

examples/models/llava/export_llava.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(self, llava):
7777
super().__init__()
7878
self.text_model = llava.text_model
7979

80-
def forward(self, input_pos, embeddings):
80+
def forward(self, embeddings, input_pos):
8181
return self.text_model(None, {"input_pos": input_pos}, embeddings)
8282

8383
llava_text_model = LlavaTextModel(llava)
@@ -88,7 +88,7 @@ def forward(self, input_pos, embeddings):
8888
max_seq_len=llava.text_model_args.max_seq_len,
8989
dtype=DType.fp32,
9090
use_kv_cache=True,
91-
example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
91+
example_inputs=(embeddings, torch.tensor([0], dtype=torch.int64)),
9292
dynamic_shapes=dynamic_shapes,
9393
)
9494

examples/models/llava/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,5 +405,5 @@ def _get_image_dynamic_shapes(self):
405405

406406
def _get_prompt_dynamic_shapes(self):
407407
dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
408-
text_model_dynamic_shapes = ({0: 1}, {1: dim})
408+
text_model_dynamic_shapes = ({1: dim}, {0: 1})
409409
return text_model_dynamic_shapes

0 commit comments

Comments
 (0)