qa-lora integration #3013
New file (example QA-LoRA config, `@@ -0,0 +1,48 @@`):

```yaml
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/qalora-out

# This will be the new adapter type once implemented
adapter: qalora
lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: false

qalora_group_size: 16

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
logging_steps: 1
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
```
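The qalora-specific pieces here are `adapter: qalora`, `qalora_group_size`, and `load_in_4bit: true`; the schema changes further down require 4-bit loading and forbid 8-bit or GPTQ loading for this adapter. As a quick, self-contained illustration of those constraints applied to a config like the one above (the `qalora.yml` filename and the `sanity_check_qalora` helper are hypothetical, not part of this PR):

```python
import yaml


def sanity_check_qalora(cfg: dict) -> None:
    # Mirrors the qalora preconditions added by this PR's validator (illustrative only).
    if cfg.get("adapter") != "qalora":
        return
    if cfg.get("load_in_8bit"):
        raise ValueError("Can't load qalora in 8bit")
    if cfg.get("gptq"):
        raise ValueError("Can't load qalora if gptq")
    if not cfg.get("load_in_4bit"):
        raise ValueError("Require load_in_4bit to be True for qalora")
    if not cfg.get("qalora_group_size"):
        # Not enforced by the PR as written; see the reviewer question at the end
        # about whether a missing value should raise.
        raise ValueError("qalora_group_size should be set for qalora")


with open("qalora.yml", encoding="utf-8") as f:  # hypothetical path for the config above
    sanity_check_qalora(yaml.safe_load(f))
```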
Next file (input config schema, `AxolotlInputConfig`):

```diff
@@ -541,6 +541,12 @@ class AxolotlInputConfig(
             "description": "Whether to use flash-attention rms norm implementation - advanced use only"
         },
     )
+    flash_attn_fuse_qkv: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to fuse QKV projection into a single operation"
+        },
+    )
```
Comment on lines +544 to +549

💡 Verification agent · 🧩 Analysis chain

`flash_attn_fuse_qkv` is defined but appears unused; add wiring + guards. Good addition, but I don't see this flag being consumed anywhere (e.g., in `PatchManager` or `_set_attention_config`), and it has no runtime guards around when it is valid to set.

Run to confirm usage across the repo:

🏁 Script executed:

```bash
#!/bin/bash
rg -n --hidden -g '!**/dist/**' -C2 '\bflash_attn_fuse_qkv\b'
```

Length of output: 855

Wire `flash_attn_fuse_qkv` into the attention-kernel configuration: right now this flag is only declared in the schema and is not read where the flash-attention kernels are set up (e.g. in `_set_attention_config`). By wiring the flag where the FA2 kernels are configured and codifying these guards, users will have both compile-time schema safety and correct runtime behavior.
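A minimal sketch of the kind of schema-level guard the review is asking for, assuming Pydantic v2; the small `AttentionConfig` model and the exact guard condition are illustrative (the real flag lives on `AxolotlInputConfig`, and the conditions to enforce are up to the PR author):

```python
from pydantic import BaseModel, model_validator


class AttentionConfig(BaseModel):
    """Illustrative subset of the input config; not the actual Axolotl schema."""

    flash_attention: bool | None = None
    flash_attn_fuse_qkv: bool | None = None

    @model_validator(mode="after")
    def check_fuse_qkv_requires_flash_attention(self):
        # Fusing the QKV projection only makes sense when the flash-attention path is active.
        if self.flash_attn_fuse_qkv and not self.flash_attention:
            raise ValueError("flash_attn_fuse_qkv requires flash_attention to be enabled")
        return self


# Example: this would raise, since flash_attention is not enabled.
# AttentionConfig(flash_attention=False, flash_attn_fuse_qkv=True)
```

A guard like this only covers the schema side; actually fusing the projections when the flag is set would still need to happen where the attention kernels are configured.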
(hunk continues)

```diff
     flash_attn_fuse_mlp: bool | None = Field(
         default=None,
         json_schema_extra={
```

Later in the same file, qalora is added to the adapters that auto-enable the LoRA kernels:

```diff
@@ -1073,7 +1079,7 @@ def check_auto_enable_lora_kernels(cls, data):
         if data.get("rl"):
             # RL trainers not tested so don't enable kernels by default
             return data
-        if data.get("adapter") in ["lora", "qlora"]:
+        if data.get("adapter") in ["lora", "qlora", "qalora"]:
             # Skip if already set, using unsloth optimizations, or using 8-bit
             unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
             kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
```
Next file (LoRA schema, `LoraConfig`):

```diff
@@ -4,6 +4,8 @@
 from pydantic import BaseModel, Field, field_validator, model_validator
 
+from axolotl.utils.schemas.enums import AdapterEnum
+
 
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
```

```diff
@@ -38,10 +40,10 @@ class LoraConfig(BaseModel):
         default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
     )
 
-    adapter: str | None = Field(
+    adapter: AdapterEnum | None = Field(
         default=None,
         json_schema_extra={
-            "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
+            "description": "If you want to use 'lora' or 'qlora' or 'qalora' or leave blank to train all parameters in original model"
         },
     )
     lora_model_dir: str | None = Field(
```
```diff
@@ -128,6 +130,10 @@ class LoraConfig(BaseModel):
             "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
         },
     )
+    qalora_group_size: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "Group size for QALoRA quantization pooling"},
+    )
 
     merge_lora: bool | None = None
 
```
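For context on what `qalora_group_size` controls: in QA-LoRA (Xu et al., 2023), the adapter input is average-pooled over groups of input channels so the low-rank correction lines up with the group-wise quantization of the base weights. The sketch below illustrates that pooling step only; the function name, shapes, and the assumption that the number of groups is `in_features // group_size` are illustrative and not taken from this PR:

```python
import torch


def qalora_adapter_delta(
    x: torch.Tensor,       # (batch, seq, in_features)
    lora_a: torch.Tensor,  # (in_features // group_size, r)
    lora_b: torch.Tensor,  # (r, out_features)
    group_size: int,
    scaling: float,
) -> torch.Tensor:
    """Rough sketch of QA-LoRA-style group pooling: average the input over channel
    groups of `group_size` before the low-rank projection, so the adapter operates
    per quantization group rather than per input channel."""
    bsz, seq_len, in_features = x.shape
    pooled = x.reshape(bsz, seq_len, in_features // group_size, group_size).mean(dim=-1)
    return (pooled @ lora_a @ lora_b) * scaling


# Shapes only, mirroring the example config (lora_r: 32, lora_alpha: 16, qalora_group_size: 16):
x = torch.randn(2, 8, 2048)
delta = qalora_adapter_delta(
    x, torch.randn(2048 // 16, 32), torch.randn(32, 4096), group_size=16, scaling=16 / 32
)
print(delta.shape)  # torch.Size([2, 8, 4096])
```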
```diff
@@ -168,6 +174,28 @@ def validate_qlora(self):
 
                 if not self.load_in_4bit:
                     raise ValueError("Require cfg.load_in_4bit to be True for qlora")
+
+        if self.adapter == "qalora":
+            if self.merge_lora:
+                # can't merge qalora if loaded in 8bit or 4bit
+                if self.load_in_8bit:
+                    raise ValueError("Can't merge qalora if loaded in 8bit")
+
+                if self.gptq:
+                    raise ValueError("Can't merge qalora if gptq")
+
+                if self.load_in_4bit:
+                    raise ValueError("Can't merge qalora if loaded in 4bit")
+
+            else:
+                if self.load_in_8bit:
+                    raise ValueError("Can't load qalora in 8bit")
+
+                if self.gptq:
+                    raise ValueError("Can't load qalora if gptq")
+
+                if not self.load_in_4bit:
+                    raise ValueError("Require cfg.load_in_4bit to be True for qalora")
         return self
```
Comment on lines +178 to 199

💡 Verification agent · 🧩 Analysis chain

Mirror the Enum comparison for qalora; add a guard for a stray `qalora_group_size`; consider clearer GPTQ limitation text.

Apply:

```diff
-        if self.adapter == "qalora":
+        if self.adapter == AdapterEnum.qalora:
             if self.merge_lora:
                 # can't merge qalora if loaded in 8bit or 4bit
                 if self.load_in_8bit:
                     raise ValueError("Can't merge qalora if loaded in 8bit")
                 if self.gptq:
-                    raise ValueError("Can't merge qalora if gptq")
+                    raise ValueError(
+                        "Can't merge qalora if gptq (adapter merging into GPTQ is not supported in this PR; PEFT also does not support it)"
+                    )
                 if self.load_in_4bit:
                     raise ValueError("Can't merge qalora if loaded in 4bit")
             else:
                 if self.load_in_8bit:
                     raise ValueError("Can't load qalora in 8bit")
                 if self.gptq:
-                    raise ValueError("Can't load qalora if gptq")
+                    raise ValueError("Can't load qalora if gptq")
                 if not self.load_in_4bit:
                     raise ValueError("Require cfg.load_in_4bit to be True for qalora")
+        # Disallow stray qalora_group_size when adapter is not qalora
+        if self.adapter != AdapterEnum.qalora and self.qalora_group_size is not None:
+            raise ValueError("qalora_group_size is only valid when adapter == 'qalora'")
```

To ensure consistency repo-wide (no lingering string-literal checks; the field is only used with qalora), run:

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail
echo "String literal adapter comparisons (should prefer AdapterEnum.*):"
rg -nP -C2 '\badapter\s*==\s*"(?:lora|qlora|qalora)"' --type=py || true
echo
echo "Usages of qalora_group_size:"
rg -n --type=py -C2 '\bqalora_group_size\b' || true
```

Length of output: 7018

Update the QALoRA schema validation: compare against `AdapterEnum.qalora`, reject a stray `qalora_group_size` when the adapter is not qalora, and clarify the GPTQ merge error message. Note that string-literal comparisons for adapters are still widespread across tests and other modules.
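The suggested diff compares against `AdapterEnum.qalora`, but the enum itself is not shown in this diff. Presumably it is a string-valued enum extended with a `qalora` member, roughly like the sketch below (an assumption about its shape, not the repository's actual definition):

```python
from enum import Enum


class AdapterEnum(str, Enum):
    """Assumed shape of axolotl.utils.schemas.enums.AdapterEnum with the new member."""

    lora = "lora"
    qlora = "qlora"
    qalora = "qalora"


# Because the members are str-valued, equality against the plain string still holds:
assert AdapterEnum.qalora == "qalora"
```

Since a str-valued enum compares equal to its plain-string value, existing `adapter == "qalora"` checks keep working, which is why the migration to `AdapterEnum.*` comparisons can proceed gradually across tests and other modules.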
(hunk continues)

```diff
 
     @field_validator("loraplus_lr_embedding")
```
Do you want to raise an error here, or possibly in validation, if a user does not provide this?