Skip to content

Commit bc755c6

Browse files
authored
[Llava] Add max_context_len CLI arg (#14599)
### Summary Add a required max_context_len argument to the Llava example model export. When set to 768, this reduces the memory consumption (~6GiB -> ~4.8GiB RSS) at the cost of a smaller context length and thus fixes #14474. ### Test plan Ran ./test_llava.sh and validated the reported memory consumption on an x86 Linux machine. ``` I 00:00:18.433471 executorch:main.cpp:172] Starting generation... I 00:00:18.433500 executorch:multimodal_runner.cpp:95] RSS after loading model: 4746.726562 MiB (0 if unsupported) I 00:00:18.433554 executorch:multimodal_runner.cpp:119] Prefilling input 0/3, type: text I 00:00:19.484581 executorch:multimodal_runner.cpp:119] Prefilling input 1/3, type: image I 00:00:19.484710 executorch:multimodal_prefiller.cpp:83] Image tensor dim: 3, dtype: Byte I 00:00:30.442685 executorch:multimodal_runner.cpp:119] Prefilling input 2/3, type: text I 00:00:30.951938 executorch:multimodal_runner.cpp:138] RSS after multimodal input processing: 4847.933594 MiB (0 if unsupported) I 00:00:30.952000 executorch:multimodal_runner.cpp:148] Max new tokens resolved: 153, pos_ 615, max_context_len 768 ```
1 parent bef9555 commit bc755c6

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

.ci/scripts/test_llava.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ cmake_build_llava_runner_for_android() {
107107
# only export the one without custom op for now since it's
108108
export_llava() {
109109
echo "Starting to export Llava. This will take about 6 mins"
110-
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
110+
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts --max-context-len 768
111111
}
112112

113113
# Download a new image

examples/models/llava/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Prerequisite: run `install_executorch.sh` to install ExecuTorch and run
4848
`examples/models/llava/install_requirements.sh` to install dependencies.
4949

5050
```bash
51-
python -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
51+
python -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts --max-context-len=768
5252
```
5353

5454
Currently the whole export process takes about 6 minutes. We also provide a

examples/models/llava/export_llava.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def create_llava_config_from_args(args):
281281
llm_config = LlmConfig()
282282

283283
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
284+
llm_config.export.max_context_length = args.max_context_len
284285
llm_config.export.max_seq_length = args.max_seq_len
285286
llm_config.export.output_name = args.pte_name
286287
llm_config.debug.profile_memory = args.profile_memory
@@ -296,6 +297,12 @@ def main():
296297
action=BooleanOptionalAction,
297298
help="Use sdpa_with_kv_cache custom op in LLava text model.",
298299
)
300+
parser.add_argument(
301+
"--max-context-len",
302+
required=True,
303+
type=int,
304+
help="Maximum context length for the text model.",
305+
)
299306
parser.add_argument(
300307
"--max-seq-len",
301308
default=768,
@@ -325,12 +332,13 @@ def main():
325332
llm_config = create_llava_config_from_args(args)
326333

327334
logging.info(
328-
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}"
335+
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}, max_context_len: {llm_config.export.max_context_length}"
329336
)
330337

331338
llava_model = LlavaModel(
332339
use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache,
333340
max_seq_len=llm_config.export.max_seq_length,
341+
max_context_len=llm_config.export.max_context_length,
334342
)
335343

336344
executorch_program = export_all(llava_model)

examples/models/llava/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
llava_model: LlavaForConditionalGeneration,
6767
image_processor: CLIPImageProcessor,
6868
use_sdpa_with_kv_cache_op: bool = True,
69+
max_context_len: int = 768,
6970
max_seq_len: int = 768,
7071
):
7172
super().__init__()
@@ -87,6 +88,7 @@ def __init__(
8788
enable_dynamic_shape=True, # allow parallel prefill
8889
use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op, # use sdpa_with_kv_cache op
8990
use_hf_rope=True,
91+
max_context_len=max_context_len,
9092
max_seq_len=max_seq_len,
9193
)
9294
self.text_model = construct_transformer(self.text_model_args)
@@ -300,8 +302,11 @@ def forward(
300302

301303

302304
class LlavaModel(EagerModelBase):
303-
def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
305+
def __init__(
306+
self, use_sdpa_with_kv_cache_op=True, max_seq_len=768, max_context_len=768
307+
):
304308
self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
309+
self.max_context_len = max_context_len
305310
self.max_seq_len = max_seq_len
306311
self.model = LlavaForConditionalGeneration.from_pretrained(
307312
"llava-hf/llava-1.5-7b-hf",
@@ -348,6 +353,7 @@ def get_eager_model(self):
348353
self.model,
349354
self.image_processor,
350355
self.use_sdpa_with_kv_cache_op,
356+
self.max_context_len,
351357
self.max_seq_len,
352358
)
353359
model.to(dtype=torch.float32)

0 commit comments

Comments
 (0)