Skip to content

Commit ffaa4f4

Browse files
committed
[Llava] Add max_context_len CLI arg
1 parent a1daab9 commit ffaa4f4

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

examples/models/llava/export_llava.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def create_llava_config_from_args(args):
281281
llm_config = LlmConfig()
282282

283283
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
284+
llm_config.export.max_context_length = args.max_context_len
284285
llm_config.export.max_seq_length = args.max_seq_len
285286
llm_config.export.output_name = args.pte_name
286287
llm_config.debug.profile_memory = args.profile_memory
@@ -296,6 +297,12 @@ def main():
296297
action=BooleanOptionalAction,
297298
help="Use sdpa_with_kv_cache custom op in LLava text model.",
298299
)
300+
parser.add_argument(
301+
"--max-context-len",
302+
default=768,
303+
type=int,
304+
help="Maximum context length for the text model.",
305+
)
299306
parser.add_argument(
300307
"--max-seq-len",
301308
default=768,
@@ -325,12 +332,13 @@ def main():
325332
llm_config = create_llava_config_from_args(args)
326333

327334
logging.info(
328-
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}"
335+
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}, max_context_len: {llm_config.export.max_context_length}"
329336
)
330337

331338
llava_model = LlavaModel(
332339
use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache,
333340
max_seq_len=llm_config.export.max_seq_length,
341+
max_context_len=llm_config.export.max_context_length,
334342
)
335343

336344
executorch_program = export_all(llava_model)

examples/models/llava/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
llava_model: LlavaForConditionalGeneration,
6767
image_processor: CLIPImageProcessor,
6868
use_sdpa_with_kv_cache_op: bool = True,
69+
max_context_len: int = 768,
6970
max_seq_len: int = 768,
7071
):
7172
super().__init__()
@@ -87,6 +88,7 @@ def __init__(
8788
enable_dynamic_shape=True, # allow parallel prefill
8889
use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op, # use sdpa_with_kv_cache op
8990
use_hf_rope=True,
91+
max_context_len=max_context_len,
9092
max_seq_len=max_seq_len,
9193
)
9294
self.text_model = construct_transformer(self.text_model_args)
@@ -300,8 +302,11 @@ def forward(
300302

301303

302304
class LlavaModel(EagerModelBase):
303-
def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
305+
def __init__(
306+
self, use_sdpa_with_kv_cache_op=True, max_seq_len=768, max_context_len=768
307+
):
304308
self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
309+
self.max_context_len = max_context_len
305310
self.max_seq_len = max_seq_len
306311
self.model = LlavaForConditionalGeneration.from_pretrained(
307312
"llava-hf/llava-1.5-7b-hf",
@@ -348,6 +353,7 @@ def get_eager_model(self):
348353
self.model,
349354
self.image_processor,
350355
self.use_sdpa_with_kv_cache_op,
356+
self.max_context_len,
351357
self.max_seq_len,
352358
)
353359
model.to(dtype=torch.float32)

0 commit comments

Comments (0)