- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 1.2k
qa-lora integration #3013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
qa-lora integration #3013
Changes from 12 commits
8b34532
              3104c6e
              573992b
              81995cc
              0a9d202
              2b0470f
              ab40a64
              da95f84
              867e35f
              366f02e
              04cc91e
              4e7ef01
              f8baf32
              0170f38
              d3fcc44
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 | ||
| model_type: LlamaForCausalLM | ||
| tokenizer_type: LlamaTokenizer | ||
|  | ||
| load_in_8bit: false | ||
| load_in_4bit: true | ||
|  | ||
| datasets: | ||
| - path: mhenrichsen/alpaca_2k_test | ||
| type: alpaca | ||
| dataset_prepared_path: | ||
| val_set_size: 0.05 | ||
| output_dir: ./outputs/qalora-out | ||
|  | ||
| # This will be the new adapter type once implemented | ||
| adapter: qalora | ||
| lora_model_dir: | ||
|  | ||
| sequence_len: 2048 | ||
| sample_packing: true | ||
| eval_sample_packing: false | ||
|  | ||
| qalora_group_size: 16 | ||
|  | ||
| lora_r: 32 | ||
| lora_alpha: 16 | ||
| lora_dropout: 0.05 | ||
| lora_target_linear: true | ||
|  | ||
| gradient_accumulation_steps: 4 | ||
| micro_batch_size: 2 | ||
| num_epochs: 1 | ||
| optimizer: paged_adamw_32bit | ||
| lr_scheduler: cosine | ||
| learning_rate: 0.0002 | ||
|  | ||
| bf16: auto | ||
| tf32: false | ||
|  | ||
| gradient_checkpointing: true | ||
| logging_steps: 1 | ||
| flash_attention: false | ||
|  | ||
| warmup_ratio: 0.1 | ||
| evals_per_epoch: 4 | ||
| saves_per_epoch: 1 | ||
| weight_decay: 0.0 | ||
| special_tokens: | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -4,6 +4,8 @@ | |
|  | ||
| from pydantic import BaseModel, Field, field_validator, model_validator | ||
|  | ||
| from axolotl.utils.schemas.enums import AdapterEnum | ||
|  | ||
|  | ||
| class LoftQConfig(BaseModel): | ||
| """LoftQ configuration subset""" | ||
|  | @@ -38,10 +40,10 @@ class LoraConfig(BaseModel): | |
| default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"} | ||
| ) | ||
|  | ||
| adapter: str | None = Field( | ||
| adapter: AdapterEnum | None = Field( | ||
| default=None, | ||
| json_schema_extra={ | ||
| "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model" | ||
| "description": "If you want to use 'lora' or 'qlora' or 'qalora' or leave blank to train all parameters in original model" | ||
| }, | ||
| ) | ||
| lora_model_dir: str | None = Field( | ||
|  | @@ -128,6 +130,10 @@ class LoraConfig(BaseModel): | |
| "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6." | ||
| }, | ||
| ) | ||
| qalora_group_size: int | None = Field( | ||
| default=None, | ||
| json_schema_extra={"description": "Group size for QALoRA quantization pooling"}, | ||
| ) | ||
|  | ||
| merge_lora: bool | None = None | ||
|  | ||
|  | @@ -168,6 +174,28 @@ def validate_qlora(self): | |
|  | ||
| if not self.load_in_4bit: | ||
| raise ValueError("Require cfg.load_in_4bit to be True for qlora") | ||
|  | ||
| if self.adapter == "qalora": | ||
| if self.merge_lora: | ||
| # can't merge qalora if loaded in 8bit or 4bit | ||
| if self.load_in_8bit: | ||
| raise ValueError("Can't merge qalora if loaded in 8bit") | ||
|  | ||
| if self.gptq: | ||
| raise ValueError("Can't merge qalora if gptq") | ||
|  | ||
| if self.load_in_4bit: | ||
| raise ValueError("Can't merge qalora if loaded in 4bit") | ||
|  | ||
| else: | ||
| if self.load_in_8bit: | ||
| raise ValueError("Can't load qalora in 8bit") | ||
|  | ||
| if self.gptq: | ||
| raise ValueError("Can't load qalora if gptq") | ||
|  | ||
| if not self.load_in_4bit: | ||
| raise ValueError("Require cfg.load_in_4bit to be True for qalora") | ||
| return self | ||
| 
      Comment on lines
    
      +178
     to 
      199
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chain: Mirror the Enum comparison for qalora; add a guard for a stray qalora_group_size; consider clearer GPTQ limitation text.
 Apply:
-        if self.adapter == "qalora":
+        if self.adapter == AdapterEnum.qalora:
             if self.merge_lora:
                 # can't merge qalora if loaded in 8bit or 4bit
                 if self.load_in_8bit:
                     raise ValueError("Can't merge qalora if loaded in 8bit")
 
                 if self.gptq:
-                    raise ValueError("Can't merge qalora if gptq")
+                    raise ValueError(
+                        "Can't merge qalora if gptq (adapter merging into GPTQ is not supported in this PR; PEFT also does not support it)"
+                    )
 
                 if self.load_in_4bit:
                     raise ValueError("Can't merge qalora if loaded in 4bit")
 
             else:
                 if self.load_in_8bit:
                     raise ValueError("Can't load qalora in 8bit")
 
                 if self.gptq:
-                    raise ValueError("Can't load qalora if gptq")
+                    raise ValueError("Can't load qalora if gptq")
 
                 if not self.load_in_4bit:
                     raise ValueError("Require cfg.load_in_4bit to be True for qalora")
+        # Disallow stray qalora_group_size when adapter is not qalora
+        if self.adapter != AdapterEnum.qalora and self.qalora_group_size is not None:
+            raise ValueError("qalora_group_size is only valid when adapter == 'qalora'")
To ensure consistency repo-wide (no lingering string-literal checks; the field is only used with qalora), run: 🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "String literal adapter comparisons (should prefer AdapterEnum.*):"
rg -nP -C2 '\badapter\s*==\s*"(?:lora|qlora|qalora)"' --type=py || true
echo
echo "Usages of qalora_group_size:"
rg -n --type=py -C2 '\bqalora_group_size\b' || true
Length of output: 7018
Update QALoRA schema validation: use  The validation logic in  
 Note that string-literal comparisons for adapters are still widespread across tests and other modules (e.g.  🤖 Prompt for AI Agents | ||
|  | ||
| @field_validator("loraplus_lr_embedding") | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -574,6 +574,21 @@ def check_lora_axolotl_unsloth(cls, data): | |
|  | ||
| @model_validator(mode="after") | ||
| def check_fused_lora(self): | ||
| if self.adapter in ["lora", "qlora", "qalora"] and ( | ||
| self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp | ||
| ): | ||
| raise ValueError("Fused modules are not supported with LoRA/QLoRA/QALoRA") | ||
| return self | ||
|          | ||
|  | ||
| @model_validator(mode="after") | ||
| def validate_qalora(self): | ||
| if self.adapter == "qalora": | ||
| if not self.load_in_4bit: | ||
| raise ValueError("QALoRA requires load_in_4bit to be True") | ||
| if not hasattr(self, "qalora_group_size") or self.qalora_group_size is None: | ||
| raise ValueError("QALoRA requires qalora_group_size to be specified") | ||
| if self.merge_lora: | ||
| raise ValueError("QALoRA does not support merge_lora yet") | ||
| if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp: | ||
| raise ValueError("Fused modules are not supported with LoRA/QLoRA") | ||
| return self | ||
|  | @@ -582,7 +597,7 @@ def check_fused_lora(self): | |
| @classmethod | ||
| def warn_qlora_zero3_w_use_reentrant(cls, data): | ||
| if ( | ||
| data.get("adapter") == "qlora" | ||
| data.get("adapter") in ["qlora", "qalora"] | ||
| and data.get("gradient_checkpointing_kwargs", {}) | ||
| and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") | ||
| is False | ||
|  | @@ -688,7 +703,7 @@ def check_rl_config_gradient_checkpointing(cls, data): | |
| and data.get("gradient_checkpointing_kwargs") | ||
| and data.get("gradient_checkpointing_kwargs").get("use_reentrant") | ||
| and data.get("load_in_4bit") | ||
| and data.get("adapter") == "qlora" | ||
| and data.get("adapter") in ["qlora", "qalora"] | ||
| and data.get("capabilities") | ||
| and data.get("capabilities").get("n_gpu", 1) > 1 | ||
| ): | ||
|  | @@ -1186,8 +1201,10 @@ def check_relora(self): | |
| if self.relora: | ||
| if not self.jagged_restart_steps: | ||
| raise ValueError("jagged_restart_steps must be set to use ReLoRA") | ||
| if self.adapter not in ("lora", "qlora"): | ||
| raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") | ||
| if self.adapter not in ("lora", "qlora", "qalora"): | ||
| raise ValueError( | ||
| "cfg.adapter must be lora, qlora, or qalora to use ReLoRA" | ||
| ) | ||
|  | ||
| if self.fsdp or self.fsdp_config: | ||
| raise ValueError("fsdp not supported with ReLoRA") | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to raise an Error here or possibly in validation if a user does not provide this?