Skip to content

Commit 63c5534

Browse files
committed
[Llava] Add max_context_len CLI arg
1 parent a1daab9 commit 63c5534

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

examples/models/llava/export_llava.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def create_llava_config_from_args(args):
281281
llm_config = LlmConfig()
282282

283283
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
284+
llm_config.export.max_context_length = args.max_context_len
284285
llm_config.export.max_seq_length = args.max_seq_len
285286
llm_config.export.output_name = args.pte_name
286287
llm_config.debug.profile_memory = args.profile_memory
@@ -296,6 +297,12 @@ def main():
296297
action=BooleanOptionalAction,
297298
help="Use sdpa_with_kv_cache custom op in LLava text model.",
298299
)
300+
parser.add_argument(
301+
"--max-context-len",
302+
default=768,
303+
type=int,
304+
help="Maximum context length for the text model.",
305+
)
299306
parser.add_argument(
300307
"--max-seq-len",
301308
default=768,
@@ -325,12 +332,13 @@ def main():
325332
llm_config = create_llava_config_from_args(args)
326333

327334
logging.info(
328-
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}"
335+
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}, max_context_len: {llm_config.export.max_context_length}"
329336
)
330337

331338
llava_model = LlavaModel(
332339
use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache,
333340
max_seq_len=llm_config.export.max_seq_length,
341+
max_context_len=llm_config.export.max_context_length,
334342
)
335343

336344
executorch_program = export_all(llava_model)

examples/models/llava/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
llava_model: LlavaForConditionalGeneration,
6767
image_processor: CLIPImageProcessor,
6868
use_sdpa_with_kv_cache_op: bool = True,
69+
max_context_len: int = 768,
6970
max_seq_len: int = 768,
7071
):
7172
super().__init__()
@@ -87,6 +88,7 @@ def __init__(
8788
enable_dynamic_shape=True, # allow parallel prefill
8889
use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op, # use sdpa_with_kv_cache op
8990
use_hf_rope=True,
91+
max_context_len=max_context_len,
9092
max_seq_len=max_seq_len,
9193
)
9294
self.text_model = construct_transformer(self.text_model_args)

0 commit comments

Comments
 (0)