[Llava] Add max_context_len CLI arg #14599
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/14599
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures, 3 Pending, 2 Unrelated Failures as of commit 315ea97 with merge base a1daab9.
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Force-pushed from 63c5534 to ffaa4f4.
kimishpatel left a comment:
I would like to make max_context_len a required arg, and if it is not required in export_llm, I think it should be. Or at least improve the documentation to include this arg in the export CLI example.
Force-pushed from ffaa4f4 to 315ea97.
I've updated the arg to be required and made corresponding changes to the example README and the test_llava.sh script. I'd recommend better documenting the difference between the user-facing max_context_len and max_seq_len args in the export_llava.py script, though I'm likely not the right owner for this.
It is also a bit hard to explain the difference between the two unless the user understands how to use it for a better memory footprint. I would just opt for a better default for max_seq_len.
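For illustration only, here is a minimal sketch of what a required max_context_len argument with a help string distinguishing it from max_seq_len could look like in an argparse-based export script. Everything here beyond the max_context_len name itself (the function, the default, the help text) is an assumption, not the actual export_llava.py interface.

```python
import argparse

# Hypothetical sketch of the CLI surface discussed above; the real
# export_llava.py argument parser may be structured differently.
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Export the Llava example model.")
    parser.add_argument(
        "--max_context_len",
        type=int,
        required=True,  # made required, per the review request above
        help=(
            "Total KV-cache capacity of the exported model (prompt + generated "
            "tokens). Smaller values reduce the memory footprint, e.g. 768."
        ),
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=768,  # assumed default, standing in for the "better default" discussed above
        help="Maximum number of tokens handled in a single prefill/forward pass.",
    )
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args()
    print(f"max_context_len={args.max_context_len}, max_seq_len={args.max_seq_len}")
```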
larryliu0820 left a comment:
Thank you for the fix!
Noting that CI failures appear to be pre-existing or flaky. Merging.
I'll submit a cherry-pick request once the trunk jobs complete.
@pytorchbot cherry-pick --onto release/1.0 -c regression |
### Summary

Add a required max_context_len argument to the Llava example model export. When set to 768, this reduces memory consumption (~6 GiB -> ~4.8 GiB RSS) at the cost of a smaller context length, and thus fixes #14474.

### Test plan

Ran ./test_llava.sh and validated the reported memory consumption on an x86 Linux machine.

```
I 00:00:18.433471 executorch:main.cpp:172] Starting generation...
I 00:00:18.433500 executorch:multimodal_runner.cpp:95] RSS after loading model: 4746.726562 MiB (0 if unsupported)
I 00:00:18.433554 executorch:multimodal_runner.cpp:119] Prefilling input 0/3, type: text
I 00:00:19.484581 executorch:multimodal_runner.cpp:119] Prefilling input 1/3, type: image
I 00:00:19.484710 executorch:multimodal_prefiller.cpp:83] Image tensor dim: 3, dtype: Byte
I 00:00:30.442685 executorch:multimodal_runner.cpp:119] Prefilling input 2/3, type: text
I 00:00:30.951938 executorch:multimodal_runner.cpp:138] RSS after multimodal input processing: 4847.933594 MiB (0 if unsupported)
I 00:00:30.952000 executorch:multimodal_runner.cpp:148] Max new tokens resolved: 153, pos_ 615, max_context_len 768
```

(cherry picked from commit bc755c6)
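As a rough illustration of why a smaller context length trims resident memory (not a measurement of the actual Llava export), here is a back-of-envelope KV-cache estimate. The layer, head, and dtype numbers are hypothetical Llama-7B-style values, chosen only to show the linear scaling with context length.

```python
# Back-of-envelope KV-cache size as a function of context length.
# The dimensions below are hypothetical and do not describe the real
# Llava export configuration.
def kv_cache_bytes(context_len: int, n_layers: int = 32, n_kv_heads: int = 32,
                   head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    # 2x for keys and values, cached for every layer, head, and position.
    return 2 * n_layers * n_kv_heads * head_dim * context_len * bytes_per_elem

for ctx in (768, 2048, 4096):
    print(f"context_len={ctx:5d}: ~{kv_cache_bytes(ctx) / 2**20:.0f} MiB of KV cache")
```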
Cherry picking #14599

The cherry pick PR is at #15112 and it is recommended to link a regression cherry pick PR with an issue. The following tracker issues are updated:

Details for Dev Infra team: raised by workflow job.