Export a lora model #11045
Changes from all commits: aac7fd6, c272cb3, 65999f0, 3887cad, d5e4d36, deaa4b6, 751286a, 6392f0b, fc15be6, bda599b
@@ -73,10 +73,16 @@ class BaseConfig:
         if it is a Llama model or the weights will be downloaded from HuggingFace
         if it is a non-Llama model.
     checkpoint_dir: Path to directory containing sharded checkpoint files.
+    adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
+        the model has trained LoRA adapters. Must provide
+        adapter_config.json.
+    adapter_config: Path to the adapter_config.json file from torchtune.
+        Used if the model has trained LoRA adapters. Must provide adapter.pt.
     tokenizer_path: Path to the tokenizer file.
     metadata: Json string containing metadata information.
         e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
-    use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+    use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled
+        if set to 0.
     fairseq2: For legacy internal use cases, this is safe to ignore.
     preq_mode: Legacy option to specify how prequantized weights are loaded.
         Going forward, ExecuTorch supports loading weights prequantized through

Review comment (on use_lora): You didn't add this, but why is this boolean-named field an int, and why does it correspond with QAT?

Review comment: cc @cccclai ?

Reply: Didn't add it myself either, I think it's likely from Lunwen. I believe it's for the llama 3.2 1b QAT checkpoint, which includes LoRA, so make sure we don't break the llama3.2 QAT model if we use this flag somewhere else.
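The docstring documents `metadata` as a JSON string (the example wraps the JSON object in an extra layer of shell quoting). As a minimal illustrative sketch, not the ExecuTorch implementation, the inner JSON could be parsed like this, using the key names from the docstring example:

```python
import json

# Metadata string in the shape shown in the docstring example.
metadata = '{"get_bos_id": 128000, "get_eos_ids": [128009, 128001]}'

parsed = json.loads(metadata)
bos_id = parsed["get_bos_id"]    # BOS token id
eos_ids = parsed["get_eos_ids"]  # list of EOS token ids
```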
@@ -90,6 +96,8 @@ class BaseConfig:
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
+    adapter_checkpoint: Optional[str] = None
+    adapter_config: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
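Since the docstring says `adapter_checkpoint` and `adapter_config` must each be provided together with the other, a post-init check is a natural companion to the new fields. A hedged sketch under that assumption — the validation logic below is illustrative and not part of this PR; only the fields touched by the diff are shown:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseConfig:
    # Subset of fields relevant to this diff.
    adapter_checkpoint: Optional[str] = None  # path to torchtune adapter.pt
    adapter_config: Optional[str] = None      # path to torchtune adapter_config.json
    use_lora: int = 0                         # LoRA rank for QAT checkpoints; 0 disables LoRA

    def __post_init__(self) -> None:
        # Per the docstring, each adapter field requires the other.
        if (self.adapter_checkpoint is None) != (self.adapter_config is None):
            raise ValueError(
                "adapter_checkpoint and adapter_config must be provided together"
            )
```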
Review comment: nit: use os.exists or similar from Path?
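The nit above suggests validating the supplied paths with the standard library. A sketch of what that might look like with `pathlib.Path` — the helper name and its parameters are assumptions for illustration, not code from this PR:

```python
from pathlib import Path


def validate_adapter_paths(adapter_checkpoint: str, adapter_config: str) -> None:
    # Hypothetical helper: fail early with a clear error if either
    # adapter file is missing, instead of deep inside model loading.
    for p in (adapter_checkpoint, adapter_config):
        if not Path(p).exists():
            raise FileNotFoundError(f"Expected file does not exist: {p}")
```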