48 changes: 48 additions & 0 deletions examples/llama-2/qalora.yml
@@ -0,0 +1,48 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/qalora-out

# This will be the new adapter type once implemented
adapter: qalora
lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: false

qalora_group_size: 16

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
logging_steps: 1
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
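
For orientation, here is a small sketch (assuming PyYAML is installed and the script is run from the repo root) of how the qalora-specific fields in this example config flow into the adapter kwargs that load_lora assembles in this PR; the r/lora_alpha/lora_dropout key names are the usual PEFT ones and are shown only for illustration:

# Sketch only: read the example config and build the extra kwargs this PR's
# load_lora() forwards when adapter == "qalora".
import yaml

with open("examples/llama-2/qalora.yml") as f:
    cfg = yaml.safe_load(f)

lora_config_kwargs = {
    "r": cfg["lora_r"],                   # 32 in this example
    "lora_alpha": cfg["lora_alpha"],      # 16
    "lora_dropout": cfg["lora_dropout"],  # 0.05
}

if cfg["adapter"] == "qalora":
    # Mirrors the new branch in src/axolotl/loaders/adapter.py
    lora_config_kwargs["use_qalora"] = True
    lora_config_kwargs["qalora_group_size"] = cfg["qalora_group_size"]  # 16

print(lora_config_kwargs)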
1 change: 1 addition & 0 deletions src/axolotl/core/builders/causal.py
@@ -164,6 +164,7 @@ def build(self, total_num_steps):
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True


# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
11 changes: 10 additions & 1 deletion src/axolotl/loaders/adapter.py
@@ -89,6 +89,15 @@ def load_lora(
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"

if cfg.adapter == "qalora":
lora_config_kwargs["use_qalora"] = True

if hasattr(cfg, "qalora_group_size") and cfg.qalora_group_size:
lora_config_kwargs["qalora_group_size"] = cfg.qalora_group_size
Comment on lines +96 to +97 (Collaborator):

Do you want to raise an Error here or possibly in validation if a user does not provide this?

else:
raise ValueError("qalora_group_size must be set when using qalora")

if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
@@ -170,7 +179,7 @@ def load_adapter(
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
if adapter in ["lora", "qlora", "qalora"]:
peft_model, lora_config = load_lora(model, cfg, inference=inference)
return peft_model, lora_config
if adapter == "llama-adapter":
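On the review question above (raise here vs. in validation), a minimal, self-contained sketch of the validation-side option, using a hypothetical standalone pydantic model rather than this PR's actual schema classes:

# Hypothetical example, not part of this PR: reject a qalora config that is
# missing qalora_group_size before the adapter loader ever runs.
from pydantic import BaseModel, ValidationError, model_validator

class QALoRASettings(BaseModel):
    adapter: str | None = None
    qalora_group_size: int | None = None

    @model_validator(mode="after")
    def require_group_size(self):
        if self.adapter == "qalora" and self.qalora_group_size is None:
            raise ValueError("qalora_group_size must be set when adapter == 'qalora'")
        return self

QALoRASettings(adapter="qalora", qalora_group_size=16)  # passes
try:
    QALoRASettings(adapter="qalora")                     # fails validation
except ValidationError as err:
    print(err)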
13 changes: 8 additions & 5 deletions src/axolotl/loaders/model.py
@@ -155,8 +155,8 @@ def is_fsdp_enabled(self):

@property
def is_qlora_and_fsdp_enabled(self):
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
"""Property that determines if FSDP with QLoRA/QALoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter in ["qlora", "qalora"]

def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
@@ -318,7 +318,7 @@ def _configure_embedding_dtypes(self):

# Apply gradient checkpointing if needed
needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
if self.cfg.adapter in ["lora", "qlora"]:
if self.cfg.adapter in ["lora", "qlora", "qalora"]:
needs_fa2_dtype = True
if self.cfg.gradient_checkpointing:
self.model.gradient_checkpointing_enable(
@@ -533,7 +533,7 @@ def _set_quantization_config(self):
**self.model_config.quantization_config
)
if (
self.cfg.adapter in ["qlora", "lora"]
self.cfg.adapter in ["qlora", "lora", "qalora"]
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
@@ -552,6 +552,9 @@
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif (
self.cfg.adapter in ["qlora", "qalora"]
and self.model_kwargs["load_in_4bit"]
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
@@ -859,7 +862,7 @@ def _prepare_model_for_quantization(self):

if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and self.cfg.adapter in ["lora", "qlora", "qalora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
2 changes: 1 addition & 1 deletion src/axolotl/utils/schemas/config.py
@@ -1073,7 +1073,7 @@ def check_auto_enable_lora_kernels(cls, data):
if data.get("rl"):
# RL trainers not tested so don't enable kernels by default
return data
if data.get("adapter") in ["lora", "qlora"]:
if data.get("adapter") in ["lora", "qlora", "qalora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
8 changes: 8 additions & 0 deletions src/axolotl/utils/schemas/enums.py
@@ -19,6 +19,14 @@ class TorchIntDType(Enum):
int8 = getattr(torch, "int8", None)


class AdapterEnum(str, Enum):
"""Adapter type configuration subset"""

lora = "lora"
qlora = "qlora"
qalora = "qalora"


class RLType(str, Enum):
"""RL trainer type configuration subset"""

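A brief aside on why this enum can be introduced without breaking the string comparisons used elsewhere in the codebase: because AdapterEnum mixes in str, its members still compare equal to their literal values. A standalone sketch (hypothetical Cfg model, not the real schema):

# Standalone illustration of the str-mixin Enum behaviour.
from enum import Enum
from pydantic import BaseModel

class AdapterEnum(str, Enum):
    """Adapter type configuration subset"""
    lora = "lora"
    qlora = "qlora"
    qalora = "qalora"

class Cfg(BaseModel):
    adapter: AdapterEnum | None = None

cfg = Cfg(adapter="qalora")          # plain strings are coerced to the enum
print(cfg.adapter)                   # AdapterEnum.qalora
print(cfg.adapter == "qalora")       # True, since the enum subclasses str
print(cfg.adapter in ["lora", "qlora", "qalora"])  # also True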
32 changes: 30 additions & 2 deletions src/axolotl/utils/schemas/peft.py
@@ -4,6 +4,8 @@

from pydantic import BaseModel, Field, field_validator, model_validator

from axolotl.utils.schemas.enums import AdapterEnum


class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -38,10 +40,10 @@ class LoraConfig(BaseModel):
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
)

adapter: str | None = Field(
adapter: AdapterEnum | None = Field(
default=None,
json_schema_extra={
"description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
"description": "If you want to use 'lora' or 'qlora' or 'qalora' or leave blank to train all parameters in original model"
},
)
lora_model_dir: str | None = Field(
@@ -128,6 +130,10 @@ class LoraConfig(BaseModel):
"description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
},
)
qalora_group_size: int | None = Field(
default=None,
json_schema_extra={"description": "Group size for QALoRA quantization pooling"},
)

merge_lora: bool | None = None

@@ -168,6 +174,28 @@ def validate_qlora(self):

if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")

if self.adapter == "qalora":
if self.merge_lora:
# can't merge qalora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qalora if loaded in 8bit")

if self.gptq:
raise ValueError("Can't merge qalora if gptq")

if self.load_in_4bit:
raise ValueError("Can't merge qalora if loaded in 4bit")

else:
if self.load_in_8bit:
raise ValueError("Can't load qalora in 8bit")

if self.gptq:
raise ValueError("Can't load qalora if gptq")

if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qalora")
return self
Comment on lines +178 to +199 (Contributor):

💡 Verification agent

🧩 Analysis chain

Mirror Enum comparison for qalora; add guard for stray qalora_group_size; consider clearer GPTQ limitation text.

  • Use AdapterEnum for comparisons.
  • Prevent configs that set qalora_group_size while adapter != qalora.
  • Optional: Expand GPTQ error text to explicitly state “not supported by this PR/PEFT parity” to match the PR objectives and avoid confusion.

Apply:

-        if self.adapter == "qalora":
+        if self.adapter == AdapterEnum.qalora:
             if self.merge_lora:
                 # can't merge qalora if loaded in 8bit or 4bit
                 if self.load_in_8bit:
                     raise ValueError("Can't merge qalora if loaded in 8bit")
 
                 if self.gptq:
-                    raise ValueError("Can't merge qalora if gptq")
+                    raise ValueError(
+                        "Can't merge qalora if gptq (adapter merging into GPTQ is not supported in this PR; PEFT also does not support it)"
+                    )
 
                 if self.load_in_4bit:
                     raise ValueError("Can't merge qalora if loaded in 4bit")
 
             else:
                 if self.load_in_8bit:
                     raise ValueError("Can't load qalora in 8bit")
 
                 if self.gptq:
-                    raise ValueError("Can't load qalora if gptq")
+                    raise ValueError("Can't load qalora if gptq")
 
                 if not self.load_in_4bit:
                     raise ValueError("Require cfg.load_in_4bit to be True for qalora")
+        # Disallow stray qalora_group_size when adapter is not qalora
+        if self.adapter != AdapterEnum.qalora and self.qalora_group_size is not None:
+            raise ValueError("qalora_group_size is only valid when adapter == 'qalora'")

To ensure consistency repo-wide (no lingering string literal checks; field is only used with qalora), run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "String literal adapter comparisons (should prefer AdapterEnum.*):"
rg -nP -C2 '\badapter\s*==\s*"(?:lora|qlora|qalora)"' --type=py || true

echo
echo "Usages of qalora_group_size:"
rg -n --type=py -C2 '\bqalora_group_size\b' || true

Length of output: 7018


Update QALoRA schema validation: use AdapterEnum and guard stray qalora_group_size

The validation logic in src/axolotl/utils/schemas/peft.py should be tightened and made consistent:

  • In the validate_qalora (lines 178–199):
    • Replace string-literal checks with the enum comparison:
      -    if self.adapter == "qalora":
      +    if self.adapter == AdapterEnum.qalora:
    • Add a guard to prevent a non-QALoRA adapter from carrying a qalora_group_size:
          # Disallow stray qalora_group_size when adapter is not qalora
          if self.adapter != AdapterEnum.qalora and self.qalora_group_size is not None:
              raise ValueError(
                  "qalora_group_size is only valid when adapter == AdapterEnum.qalora"
              )
  • Optional: Clarify the GPTQ limitation message in the merge branch to reflect “not supported by this PR/PEFT parity.”

Note that string-literal comparisons for adapters are still widespread across tests and other modules (e.g. tests/test_loaders.py, src/axolotl/loaders/adapter.py, etc.). We recommend migrating all of these to AdapterEnum in a follow-up refactor for consistency.

🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/peft.py around lines 178–199, update
validate_qalora to use AdapterEnum instead of string literals when checking for
qalora, and add a guard that raises a ValueError if self.adapter !=
AdapterEnum.qalora and self.qalora_group_size is not None (message:
"qalora_group_size is only valid when adapter == AdapterEnum.qalora"). Also
change the existing qalora-related checks to compare against AdapterEnum.qalora,
and adjust the GPTQ/merge error message in the merge branch to indicate GPTQ is
not supported by this PR/PEFT parity.
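
Putting the PR's validator and the suggestions above together, a rough standalone sketch (hypothetical QALoRAConfig model, not the merged result) of what the combined checks could look like:

# Sketch of the combined qalora checks, expressed as a standalone pydantic
# model so it can be run in isolation.
from enum import Enum
from pydantic import BaseModel, model_validator

class AdapterEnum(str, Enum):
    lora = "lora"
    qlora = "qlora"
    qalora = "qalora"

class QALoRAConfig(BaseModel):
    adapter: AdapterEnum | None = None
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    gptq: bool = False
    merge_lora: bool = False
    qalora_group_size: int | None = None

    @model_validator(mode="after")
    def validate_qalora(self):
        if self.adapter == AdapterEnum.qalora:
            if self.merge_lora:
                # can't merge qalora if loaded in 8bit or 4bit, or with gptq
                if self.load_in_8bit:
                    raise ValueError("Can't merge qalora if loaded in 8bit")
                if self.gptq:
                    raise ValueError("Can't merge qalora if gptq")
                if self.load_in_4bit:
                    raise ValueError("Can't merge qalora if loaded in 4bit")
            else:
                if self.load_in_8bit:
                    raise ValueError("Can't load qalora in 8bit")
                if self.gptq:
                    raise ValueError("Can't load qalora if gptq")
                if not self.load_in_4bit:
                    raise ValueError("Require cfg.load_in_4bit to be True for qalora")
        elif self.qalora_group_size is not None:
            # Guard suggested above: no stray qalora_group_size for other adapters
            raise ValueError("qalora_group_size is only valid when adapter == 'qalora'")
        return self

QALoRAConfig(adapter="qalora", load_in_4bit=True, qalora_group_size=16)  # ok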


@field_validator("loraplus_lr_embedding")
25 changes: 21 additions & 4 deletions src/axolotl/utils/schemas/validation.py
@@ -574,6 +574,21 @@ def check_lora_axolotl_unsloth(cls, data):

@model_validator(mode="after")
def check_fused_lora(self):
if self.adapter in ["lora", "qlora", "qalora"] and (
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
):
raise ValueError("Fused modules are not supported with LoRA/QLoRA/QALoRA")
return self
Comment (Collaborator):

I think this is a merge conflict as we removed flash_attn_fuse_qkv recently.


@model_validator(mode="after")
def validate_qalora(self):
if self.adapter == "qalora":
if not self.load_in_4bit:
raise ValueError("QALoRA requires load_in_4bit to be True")
if not hasattr(self, "qalora_group_size") or self.qalora_group_size is None:
raise ValueError("QALoRA requires qalora_group_size to be specified")
if self.merge_lora:
raise ValueError("QALoRA does not support merge_lora yet")
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self
@@ -582,7 +597,7 @@ def check_fused_lora(self):
@classmethod
def warn_qlora_zero3_w_use_reentrant(cls, data):
if (
data.get("adapter") == "qlora"
data.get("adapter") in ["qlora", "qalora"]
and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
@@ -688,7 +703,7 @@ def check_rl_config_gradient_checkpointing(cls, data):
and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
and data.get("load_in_4bit")
and data.get("adapter") == "qlora"
and data.get("adapter") in ["qlora", "qalora"]
and data.get("capabilities")
and data.get("capabilities").get("n_gpu", 1) > 1
):
@@ -1186,8 +1201,10 @@ def check_relora(self):
if self.relora:
if not self.jagged_restart_steps:
raise ValueError("jagged_restart_steps must be set to use ReLoRA")
if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if self.adapter not in ("lora", "qlora", "qalora"):
raise ValueError(
"cfg.adapter must be lora, qlora, or qalora to use ReLoRA"
)

if self.fsdp or self.fsdp_config:
raise ValueError("fsdp not supported with ReLoRA")
2 changes: 1 addition & 1 deletion tests/test_loaders.py
@@ -57,7 +57,7 @@ def test_set_device_map_config(self):
# check torch_dtype
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]

@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
@pytest.mark.parametrize("adapter", ["lora", "qlora", "qalora", None])
@pytest.mark.parametrize("load_in_8bit", [True, False])
@pytest.mark.parametrize("load_in_4bit", [True, False])
@pytest.mark.parametrize("gptq", [True, False])