
Commit 9fa7edf

Not intended for landing
1 parent 76a8906 commit 9fa7edf

15 files changed: 48 additions & 810 deletions

examples/models/llava/export_llava.py

Lines changed: 3 additions & 3 deletions
@@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel):
         {
             "image_encoder": image_encoder_ep,
             "token_embedding": token_embedding_ep,
-            "text_decoder": text_model_ep,
+            "text_model": text_model_ep,
         },
         partitioner={
             "image_encoder": [XnnpackPartitioner()],
-            "text_decoder": [
+            "text_model": [
                 # First partition the DQLinear nodes, then partition the rest of the nodes,
                 # to avoid multiple DQLinear nodes in the same partition,
                 # to avoid holding multiple unpacked and packed weight buffers in memory,
@@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
         memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
         sym_shape_eval_pass={
             "image_encoder": ConstraintBasedSymShapeEvalPass(),
-            "text_decoder": ConstraintBasedSymShapeEvalPass(),
+            "text_model": ConstraintBasedSymShapeEvalPass(),
             "token_embedding": HintBasedSymShapeEvalPass(),
         },
     )
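
Note on the partitioner change above: the comment kept in the export_llava.py hunk describes a two-pass partitioning scheme for the text model, where the dynamically quantized linear (DQLinear) nodes are lowered first, each in its own partition, and the remaining nodes are swept up afterwards, so that packed and unpacked weight buffers are not held in memory at the same time. A minimal sketch of such a partitioner mapping is shown below; the ConfigPrecisionType.DYNAMIC_QUANT and per_op_mode arguments are assumptions about the XNNPACK partitioner's options, not lines taken from this commit.

    # Sketch only: a two-pass XNNPACK partitioner list for the renamed "text_model"
    # method, assuming XnnpackPartitioner's config_precisions/per_op_mode options.
    from executorch.backends.xnnpack.partition.config.xnnpack_config import (
        ConfigPrecisionType,
    )
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackPartitioner,
    )

    partitioner = {
        "image_encoder": [XnnpackPartitioner()],
        "text_model": [
            # Pass 1: partition only the dynamically quantized linear nodes,
            # one partition per op, so packed weights can be released early.
            XnnpackPartitioner(
                config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
                per_op_mode=True,
            ),
            # Pass 2: partition whatever remains.
            XnnpackPartitioner(),
        ],
    }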

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
   }

   inline static const std::string kTokenEmbeddingMethod = "token_embedding";
-  inline static const std::string kTextModelMethod = "text_decoder";
+  inline static const std::string kTextModelMethod = "text_model";
 };

 } // namespace example

examples/models/llava/test/test_llava.py

Lines changed: 4 additions & 4 deletions
@@ -96,7 +96,7 @@ def test_llava_export(self):
             "token_embedding", (prompt_before_image,)
         )[0]
         llava_module.run_method(
-            "text_decoder",
+            "text_model",
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
         )

@@ -107,7 +107,7 @@ def test_llava_export(self):
         # pte prefill image
         pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
         llava_module.run_method(
-            "text_decoder",
+            "text_model",
             (
                 torch.tensor([start_pos], dtype=torch.int64),
                 pte_embeds_img,
@@ -122,7 +122,7 @@ def test_llava_export(self):
             "token_embedding", (prompt_after_image,)
         )[0]
         pte_prefill_after_img = llava_module.run_method(
-            "text_decoder",
+            "text_model",
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
         )[0]

@@ -139,7 +139,7 @@ def test_llava_export(self):
                 "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
             )[0]
             logits = llava_module.run_method(
-                "text_decoder",
+                "text_model",
                 (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
             )[0]
             new_tokens.append(torch.argmax(logits).item())

examples/models/llava/test/test_pte.py

Lines changed: 4 additions & 4 deletions
@@ -47,7 +47,7 @@ def main():
         "token_embedding", (prompt_before_image,)
     )[0]
     pte_prefill_before_img = llava_module.run_method(
-        "text_decoder",
+        "text_model",
         (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
     )[0]
     print(pte_prefill_before_img)
@@ -60,7 +60,7 @@ def main():
     logging.warning("Image encoder finished")
     logging.warning("Image token prefill started")
     pte_prefill_img = llava_module.run_method(
-        "text_decoder",
+        "text_model",
         (
             torch.tensor([start_pos], dtype=torch.int64),
             pte_embeds_img,
@@ -77,7 +77,7 @@ def main():
         "token_embedding", (prompt_after_image,)
     )[0]
     pte_prefill_after_img = llava_module.run_method(
-        "text_decoder",
+        "text_model",
         (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
     )[0]
     logging.warning("Text token prefill finished")
@@ -91,7 +91,7 @@ def main():
             "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
         )[0]
         logits = llava_module.run_method(
-            "text_decoder",
+            "text_model",
            (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
         )[0]
         new_tokens.append(torch.argmax(logits[..., -1, :]).item())

examples/models/voxtral/CMakeLists.txt

Lines changed: 0 additions & 99 deletions
This file was deleted.
