Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/test_llava.sh
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ cmake_build_llava_runner_for_android() {
# only export the one without custom op for now since it's
export_llava() {
echo "Starting to export Llava. This will take about 6 mins"
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts --max-context-len 768
}

# Download a new image
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llava/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Prerequisite: run `install_executorch.sh` to install ExecuTorch and run
`examples/models/llava/install_requirements.sh` to install dependencies.

```bash
python -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
python -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts --max-context-len=768
```

Currently the whole export process takes about 6 minutes. We also provide a
Expand Down
10 changes: 9 additions & 1 deletion examples/models/llava/export_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def create_llava_config_from_args(args):
llm_config = LlmConfig()

llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
llm_config.export.max_context_length = args.max_context_len
llm_config.export.max_seq_length = args.max_seq_len
llm_config.export.output_name = args.pte_name
llm_config.debug.profile_memory = args.profile_memory
Expand All @@ -296,6 +297,12 @@ def main():
action=BooleanOptionalAction,
help="Use sdpa_with_kv_cache custom op in LLava text model.",
)
parser.add_argument(
"--max-context-len",
required=True,
type=int,
help="Maximum context length for the text model.",
)
parser.add_argument(
"--max-seq-len",
default=768,
Expand Down Expand Up @@ -325,12 +332,13 @@ def main():
llm_config = create_llava_config_from_args(args)

logging.info(
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}"
f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}, max_context_len: {llm_config.export.max_context_length}"
)

llava_model = LlavaModel(
use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache,
max_seq_len=llm_config.export.max_seq_length,
max_context_len=llm_config.export.max_context_length,
)

executorch_program = export_all(llava_model)
Expand Down
8 changes: 7 additions & 1 deletion examples/models/llava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
llava_model: LlavaForConditionalGeneration,
image_processor: CLIPImageProcessor,
use_sdpa_with_kv_cache_op: bool = True,
max_context_len: int = 768,
max_seq_len: int = 768,
):
super().__init__()
Expand All @@ -87,6 +88,7 @@ def __init__(
enable_dynamic_shape=True, # allow parallel prefill
use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op, # use sdpa_with_kv_cache op
use_hf_rope=True,
max_context_len=max_context_len,
max_seq_len=max_seq_len,
)
self.text_model = construct_transformer(self.text_model_args)
Expand Down Expand Up @@ -300,8 +302,11 @@ def forward(


class LlavaModel(EagerModelBase):
def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
def __init__(
self, use_sdpa_with_kv_cache_op=True, max_seq_len=768, max_context_len=768
):
self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
self.max_context_len = max_context_len
self.max_seq_len = max_seq_len
self.model = LlavaForConditionalGeneration.from_pretrained(
"llava-hf/llava-1.5-7b-hf",
Expand Down Expand Up @@ -348,6 +353,7 @@ def get_eager_model(self):
self.model,
self.image_processor,
self.use_sdpa_with_kv_cache_op,
self.max_context_len,
self.max_seq_len,
)
model.to(dtype=torch.float32)
Expand Down
Loading