From 8b345322edc0539bd2ba7f492ed6057654a55529 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 4 Aug 2025 19:18:00 +0530 Subject: [PATCH 01/13] qa-lora integration --- src/axolotl/core/builders/causal.py | 3 +++ src/axolotl/core/training_args_base.py | 8 ++++++ src/axolotl/loaders/adapter.py | 9 ++++++- src/axolotl/loaders/model.py | 12 ++++----- src/axolotl/utils/schemas/config.py | 2 +- src/axolotl/utils/schemas/peft.py | 36 ++++++++++++++++++++++++- src/axolotl/utils/schemas/validation.py | 21 +++++++++++---- tests/test_loaders.py | 2 +- 8 files changed, 78 insertions(+), 15 deletions(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index b461e90092..8fb5569ee7 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -151,6 +151,9 @@ def build(self, total_num_steps): if self.cfg.adapter == "qlora": training_arguments_kwargs["qlora"] = True + if self.cfg.adapter == "qalora": + training_arguments_kwargs["qalora"] = True + # deepspeed if self.cfg.deepspeed: training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py index 66649deefd..71d2d041ab 100644 --- a/src/axolotl/core/training_args_base.py +++ b/src/axolotl/core/training_args_base.py @@ -163,6 +163,14 @@ class AxolotlTrainingMixins: default=False, metadata={"help": "whether this is a qlora training"}, ) + qalora: bool = field( + default=False, + metadata={"help": "whether this is a qalora training"}, + ) + qalora_group_size: Optional[int] = field( + default=16, + metadata={"help": "Group size for QALoRA quantization"}, + ) orpo_alpha: Optional[float] = field( default=None, ) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index db28206b6c..cb3de4a8c6 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -93,6 +93,13 @@ def load_lora( if loftq_bits: lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) lora_config_kwargs["init_lora_weights"] = "loftq" + + if cfg.adapter == "qalora": + if hasattr(cfg, 'use_qalora') and cfg.use_qalora: + lora_config_kwargs["use_qalora"] = True + if hasattr(cfg, 'qalora_group_size') and cfg.qalora_group_size: + lora_config_kwargs["qalora_group_size"] = cfg.qalora_group_size + if cfg.peft_init_lora_weights: lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights if cfg.peft_use_dora: @@ -174,7 +181,7 @@ def load_adapter( return model, None if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() - if adapter in ["lora", "qlora"]: + if adapter in ["lora", "qlora" , "qalora"]: peft_model, lora_config = load_lora(model, cfg, inference=inference) return peft_model, lora_config if adapter == "llama-adapter": diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 05039c9ee9..6c4ef49ac5 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -148,8 +148,8 @@ def is_fsdp_enabled(self): @property def is_qlora_and_fsdp_enabled(self): - """Property that determines if FSDP with QLoRA is enabled.""" - return self.is_fsdp_enabled and self.cfg.adapter == "qlora" + """Property that determines if FSDP with QLoRA/QALoRA is enabled.""" + return self.is_fsdp_enabled and self.cfg.adapter in ["qlora", "qalora"] def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: """Load and prepare the model with all configurations and patches. 
@@ -305,7 +305,7 @@ def _configure_embedding_dtypes(self): # Apply gradient checkpointing if needed needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled - if self.cfg.adapter in ["lora", "qlora"]: + if self.cfg.adapter in ["lora", "qlora", "qalora"]: needs_fa2_dtype = True if self.cfg.gradient_checkpointing: self.model.gradient_checkpointing_enable( @@ -582,7 +582,7 @@ def _set_quantization_config(self): **self.model_config.quantization_config ) if ( - self.cfg.adapter in ["qlora", "lora"] + self.cfg.adapter in ["qlora", "lora", "qalora"] and hasattr(self.model_config, "quantization_config") and self.model_config.quantization_config["quant_method"] in ["gptq", "awq", "bitsandbytes"] @@ -601,7 +601,7 @@ def _set_quantization_config(self): self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **self.model_config.quantization_config ) - elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: + elif self.cfg.adapter in ["qlora", "qalora"] and self.model_kwargs["load_in_4bit"]: bnb_config = { "load_in_4bit": True, "llm_int8_threshold": 6.0, @@ -886,7 +886,7 @@ def _prepare_model_for_quantization(self): if ( not skip_prepare_model_for_kbit_training - and self.cfg.adapter in ["lora", "qlora"] + and self.cfg.adapter in ["lora", "qlora" , "qalora"] and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 1d089ba41f..415fb6c411 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1064,7 +1064,7 @@ def check_auto_enable_lora_kernels(cls, data): if data.get("rl"): # RL trainers not tested so don't enable kernels by default return data - if data.get("adapter") in ["lora", "qlora"]: + if data.get("adapter") in ["lora", "qlora", "qalora"]: # Skip if already set, using unsloth optimizations, or using 8-bit unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index de29521cb4..d7815078cb 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -41,7 +41,7 @@ class LoraConfig(BaseModel): adapter: str | None = Field( default=None, json_schema_extra={ - "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model" + "description": "If you want to use 'lora' or 'qlora' or 'qalora' or leave blank to train all parameters in original model" }, ) lora_model_dir: str | None = Field( @@ -128,6 +128,18 @@ class LoraConfig(BaseModel): "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6." 
}, ) + use_qalora: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Enable Quantization-Aware Low-Rank Adaptation" + }, + ) + qalora_group_size: int | None = Field( + default=16, + json_schema_extra={ + "description": "Group size for QALoRA quantization pooling" + }, + ) merge_lora: bool | None = None @@ -168,6 +180,28 @@ def validate_qlora(self): if not self.load_in_4bit: raise ValueError("Require cfg.load_in_4bit to be True for qlora") + + if self.adapter == "qalora": + if self.merge_lora: + # can't merge qalora if loaded in 8bit or 4bit + if self.load_in_8bit: + raise ValueError("Can't merge qalora if loaded in 8bit") + + if self.gptq: + raise ValueError("Can't merge qalora if gptq") + + if self.load_in_4bit: + raise ValueError("Can't merge qalora if loaded in 4bit") + + else: + if self.load_in_8bit: + raise ValueError("Can't load qalora in 8bit") + + if self.gptq: + raise ValueError("Can't load qalora if gptq") + + if not self.load_in_4bit: + raise ValueError("Require cfg.load_in_4bit to be True for qalora") return self @field_validator("loraplus_lr_embedding") diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 02e80dd8e0..421fcd5dd0 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -591,17 +591,28 @@ def check_lora_axolotl_unsloth(cls, data): @model_validator(mode="after") def check_fused_lora(self): - if self.adapter in ["lora", "qlora"] and ( + if self.adapter in ["lora", "qlora", "qalora"] and ( self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp ): raise ValueError("Fused modules are not supported with LoRA/QLoRA") return self + + @model_validator(mode="after") + def validate_qalora(self): + if self.adapter == "qalora": + if not self.load_in_4bit: + raise ValueError("QALoRA requires load_in_4bit to be True") + if not hasattr(self, 'qalora_group_size') or self.qalora_group_size is None: + raise ValueError("QALoRA requires qalora_group_size to be specified") + if self.merge_lora: + raise ValueError("QALoRA does not support merge_lora yet") + return self @model_validator(mode="before") @classmethod def warn_qlora_zero3_w_use_reentrant(cls, data): if ( - data.get("adapter") == "qlora" + data.get("adapter") in ["qlora", "qalora"] and data.get("gradient_checkpointing_kwargs", {}) and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False @@ -691,7 +702,7 @@ def check_rl_config_gradient_checkpointing(cls, data): and data.get("gradient_checkpointing_kwargs") and data.get("gradient_checkpointing_kwargs").get("use_reentrant") and data.get("load_in_4bit") - and data.get("adapter") == "qlora" + and data.get("adapter") in ["qlora", "qalora"] and data.get("capabilities") and data.get("capabilities").get("n_gpu", 1) > 1 ): @@ -1168,8 +1179,8 @@ def check_relora(self): if self.relora: if not self.jagged_restart_steps: raise ValueError("jagged_restart_steps must be set to use ReLoRA") - if self.adapter not in ("lora", "qlora"): - raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") + if self.adapter not in ("lora", "qlora", "qalora"): + raise ValueError("cfg.adapter must be lora, qlora, or qalora to use ReLoRA") if self.fsdp or self.fsdp_config: raise ValueError("fsdp not supported with ReLoRA") diff --git a/tests/test_loaders.py b/tests/test_loaders.py index def7672b97..991c9e8ec9 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -60,7 +60,7 @@ def test_set_device_map_config(self): # check torch_dtype assert 
self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"] - @pytest.mark.parametrize("adapter", ["lora", "qlora", None]) + @pytest.mark.parametrize("adapter", ["lora", "qlora", "qalora", None]) @pytest.mark.parametrize("load_in_8bit", [True, False]) @pytest.mark.parametrize("load_in_4bit", [True, False]) @pytest.mark.parametrize("gptq", [True, False]) From 3104c6e4d351cf565e3eedcbe25c386c6024991d Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 4 Aug 2025 19:19:41 +0530 Subject: [PATCH 02/13] lint + yml --- examples/llama-2/qalora.yml | 49 +++++++++++++++++++++++++ src/axolotl/loaders/adapter.py | 6 +-- src/axolotl/loaders/model.py | 7 +++- src/axolotl/utils/schemas/peft.py | 6 +-- src/axolotl/utils/schemas/validation.py | 8 ++-- 5 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 examples/llama-2/qalora.yml diff --git a/examples/llama-2/qalora.yml b/examples/llama-2/qalora.yml new file mode 100644 index 0000000000..ae27f6257b --- /dev/null +++ b/examples/llama-2/qalora.yml @@ -0,0 +1,49 @@ +base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/qalora-out + +# This will be the new adapter type once implemented +adapter: qalora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: false + +qlora: true +qalora_group_size: 16 + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_32bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: false + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: \ No newline at end of file diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index cb3de4a8c6..c867977aa9 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -95,9 +95,9 @@ def load_lora( lora_config_kwargs["init_lora_weights"] = "loftq" if cfg.adapter == "qalora": - if hasattr(cfg, 'use_qalora') and cfg.use_qalora: + if hasattr(cfg, "use_qalora") and cfg.use_qalora: lora_config_kwargs["use_qalora"] = True - if hasattr(cfg, 'qalora_group_size') and cfg.qalora_group_size: + if hasattr(cfg, "qalora_group_size") and cfg.qalora_group_size: lora_config_kwargs["qalora_group_size"] = cfg.qalora_group_size if cfg.peft_init_lora_weights: @@ -181,7 +181,7 @@ def load_adapter( return model, None if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() - if adapter in ["lora", "qlora" , "qalora"]: + if adapter in ["lora", "qlora", "qalora"]: peft_model, lora_config = load_lora(model, cfg, inference=inference) return peft_model, lora_config if adapter == "llama-adapter": diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 6c4ef49ac5..4c94a67827 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -601,7 +601,10 @@ def _set_quantization_config(self): self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **self.model_config.quantization_config ) - elif self.cfg.adapter in ["qlora", "qalora"] and self.model_kwargs["load_in_4bit"]: + elif ( + self.cfg.adapter in ["qlora", "qalora"] + and self.model_kwargs["load_in_4bit"] + ): bnb_config = { 
"load_in_4bit": True, "llm_int8_threshold": 6.0, @@ -886,7 +889,7 @@ def _prepare_model_for_quantization(self): if ( not skip_prepare_model_for_kbit_training - and self.cfg.adapter in ["lora", "qlora" , "qalora"] + and self.cfg.adapter in ["lora", "qlora", "qalora"] and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index d7815078cb..12acda9d7f 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -136,9 +136,7 @@ class LoraConfig(BaseModel): ) qalora_group_size: int | None = Field( default=16, - json_schema_extra={ - "description": "Group size for QALoRA quantization pooling" - }, + json_schema_extra={"description": "Group size for QALoRA quantization pooling"}, ) merge_lora: bool | None = None @@ -180,7 +178,7 @@ def validate_qlora(self): if not self.load_in_4bit: raise ValueError("Require cfg.load_in_4bit to be True for qlora") - + if self.adapter == "qalora": if self.merge_lora: # can't merge qalora if loaded in 8bit or 4bit diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 421fcd5dd0..db5d538bd2 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -596,13 +596,13 @@ def check_fused_lora(self): ): raise ValueError("Fused modules are not supported with LoRA/QLoRA") return self - + @model_validator(mode="after") def validate_qalora(self): if self.adapter == "qalora": if not self.load_in_4bit: raise ValueError("QALoRA requires load_in_4bit to be True") - if not hasattr(self, 'qalora_group_size') or self.qalora_group_size is None: + if not hasattr(self, "qalora_group_size") or self.qalora_group_size is None: raise ValueError("QALoRA requires qalora_group_size to be specified") if self.merge_lora: raise ValueError("QALoRA does not support merge_lora yet") @@ -1180,7 +1180,9 @@ def check_relora(self): if not self.jagged_restart_steps: raise ValueError("jagged_restart_steps must be set to use ReLoRA") if self.adapter not in ("lora", "qlora", "qalora"): - raise ValueError("cfg.adapter must be lora, qlora, or qalora to use ReLoRA") + raise ValueError( + "cfg.adapter must be lora, qlora, or qalora to use ReLoRA" + ) if self.fsdp or self.fsdp_config: raise ValueError("fsdp not supported with ReLoRA") From 573992bcd1f46cd7400f59e1c47deb1a8c554850 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 11:45:44 -0400 Subject: [PATCH 03/13] chore: lint --- examples/llama-2/qalora.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama-2/qalora.yml b/examples/llama-2/qalora.yml index ae27f6257b..7c2bfc2675 100644 --- a/examples/llama-2/qalora.yml +++ b/examples/llama-2/qalora.yml @@ -46,4 +46,4 @@ warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 -special_tokens: \ No newline at end of file +special_tokens: From 81995cc4c735bf644f880f2ada1a17a694b2b642 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:34:06 +0530 Subject: [PATCH 04/13] Update src/axolotl/utils/schemas/peft.py Co-authored-by: NanoCode012 --- src/axolotl/utils/schemas/peft.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index 12acda9d7f..c46755d9db 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -128,12 +128,6 @@ 
class LoraConfig(BaseModel): "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6." }, ) - use_qalora: bool | None = Field( - default=False, - json_schema_extra={ - "description": "Enable Quantization-Aware Low-Rank Adaptation" - }, - ) qalora_group_size: int | None = Field( default=16, json_schema_extra={"description": "Group size for QALoRA quantization pooling"}, From 0a9d202204f786854d923589aeb9286f247745f5 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:34:15 +0530 Subject: [PATCH 05/13] Update src/axolotl/utils/schemas/peft.py Co-authored-by: NanoCode012 --- src/axolotl/utils/schemas/peft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index c46755d9db..fb3c313181 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -129,7 +129,7 @@ class LoraConfig(BaseModel): }, ) qalora_group_size: int | None = Field( - default=16, + default=None, json_schema_extra={"description": "Group size for QALoRA quantization pooling"}, ) From 2b0470fdc25e91d701984c709c0d69ed00858392 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:34:40 +0530 Subject: [PATCH 06/13] Update src/axolotl/utils/schemas/validation.py Co-authored-by: NanoCode012 --- src/axolotl/utils/schemas/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index db5d538bd2..7131efd03e 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -594,7 +594,7 @@ def check_fused_lora(self): if self.adapter in ["lora", "qlora", "qalora"] and ( self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp ): - raise ValueError("Fused modules are not supported with LoRA/QLoRA") + raise ValueError("Fused modules are not supported with LoRA/QLoRA/QALoRA") return self @model_validator(mode="after") From ab40a6459f74c77e1a1b4f74acf51e7b6349cb5b Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:35:05 +0530 Subject: [PATCH 07/13] Update examples/llama-2/qalora.yml Co-authored-by: NanoCode012 --- examples/llama-2/qalora.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llama-2/qalora.yml b/examples/llama-2/qalora.yml index 7c2bfc2675..ceca01b825 100644 --- a/examples/llama-2/qalora.yml +++ b/examples/llama-2/qalora.yml @@ -20,7 +20,6 @@ sequence_len: 2048 sample_packing: true eval_sample_packing: false -qlora: true qalora_group_size: 16 lora_r: 32 From da95f8445022adca22510bfa535bccbecbd6021b Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:35:38 +0530 Subject: [PATCH 08/13] Update src/axolotl/core/training_args_base.py Co-authored-by: NanoCode012 --- src/axolotl/core/training_args_base.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py index 71d2d041ab..66649deefd 100644 --- a/src/axolotl/core/training_args_base.py +++ b/src/axolotl/core/training_args_base.py @@ -163,14 +163,6 @@ class AxolotlTrainingMixins: default=False, metadata={"help": "whether this is a qlora training"}, ) - qalora: bool = field( - default=False, - metadata={"help": "whether this is a qalora training"}, - ) - qalora_group_size: Optional[int] = field( - 
default=16, - metadata={"help": "Group size for QALoRA quantization"}, - ) orpo_alpha: Optional[float] = field( default=None, ) From 867e35f47cc5ab9eb73b1a13eb848621c2bb348b Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:35:56 +0530 Subject: [PATCH 09/13] Update src/axolotl/core/builders/causal.py Co-authored-by: NanoCode012 --- src/axolotl/core/builders/causal.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 8fb5569ee7..4d432512e0 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -151,8 +151,6 @@ def build(self, total_num_steps): if self.cfg.adapter == "qlora": training_arguments_kwargs["qlora"] = True - if self.cfg.adapter == "qalora": - training_arguments_kwargs["qalora"] = True # deepspeed if self.cfg.deepspeed: From 04cc91ee032758df7f83e54a5e0b71754d8396cc Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 25 Aug 2025 13:19:07 +0530 Subject: [PATCH 10/13] add enum,warn qalora_group_size --- src/axolotl/loaders/adapter.py | 6 ++++-- src/axolotl/utils/schemas/enums.py | 8 ++++++++ src/axolotl/utils/schemas/peft.py | 4 +++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index c867977aa9..5cd9e8662d 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -95,10 +95,12 @@ def load_lora( lora_config_kwargs["init_lora_weights"] = "loftq" if cfg.adapter == "qalora": - if hasattr(cfg, "use_qalora") and cfg.use_qalora: - lora_config_kwargs["use_qalora"] = True + lora_config_kwargs["use_qalora"] = True + if hasattr(cfg, "qalora_group_size") and cfg.qalora_group_size: lora_config_kwargs["qalora_group_size"] = cfg.qalora_group_size + else: + ValueError("qalora_group_size must be set when using qalora") if cfg.peft_init_lora_weights: lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 3c88283962..434d77e370 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -21,6 +21,14 @@ class TorchIntDType(Enum): int8 = getattr(torch, "int8", None) +class AdapterEnum(str, Enum): + """Adapter type configuration subset""" + + lora = "lora" + qlora = "qlora" + qalora = "qalora" + + class RLType(str, Enum): """RL trainer type configuration subset""" diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index 12acda9d7f..1fb5518786 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, Field, field_validator, model_validator +from axolotl.utils.schemas.enums import AdapterEnum + class LoftQConfig(BaseModel): """LoftQ configuration subset""" @@ -38,7 +40,7 @@ class LoraConfig(BaseModel): default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"} ) - adapter: str | None = Field( + adapter: AdapterEnum | None = Field( default=None, json_schema_extra={ "description": "If you want to use 'lora' or 'qlora' or 'qalora' or leave blank to train all parameters in original model" From f8baf32460f4f1c45e72669c710dedcf94114092 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 25 Aug 2025 13:25:56 +0530 Subject: [PATCH 11/13] ruff chore --- src/axolotl/core/builders/causal.py | 1 - src/axolotl/loaders/model.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) 
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index de3a885f87..94b0db8515 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -164,7 +164,6 @@ def build(self, total_num_steps): if self.cfg.adapter == "qlora": training_arguments_kwargs["qlora"] = True - # deepspeed if self.cfg.deepspeed: training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 925b2c1f15..c316b1a3d0 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -555,6 +555,10 @@ def _set_quantization_config(self): elif ( self.cfg.adapter in ["qlora", "qalora"] and self.model_kwargs["load_in_4bit"] + ): + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **self.model_config.quantization_config + ) elif self.cfg.adapter == "qlora" and self.model_kwargs.get( "load_in_4bit", False ): From 0170f386cb2335cac2b48f1e78412bfa80a13474 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Mon, 25 Aug 2025 16:27:54 +0530 Subject: [PATCH 12/13] test fix --- src/axolotl/loaders/model.py | 3 ++- src/axolotl/utils/schemas/config.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index c316b1a3d0..1185c02bb7 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -556,8 +556,9 @@ def _set_quantization_config(self): self.cfg.adapter in ["qlora", "qalora"] and self.model_kwargs["load_in_4bit"] ): + quantization_config = getattr(self.model_config, "quantization_config", {}) self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **self.model_config.quantization_config + **quantization_config ) elif self.cfg.adapter == "qlora" and self.model_kwargs.get( "load_in_4bit", False diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 1004a08d21..1b55a4e712 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -541,6 +541,12 @@ class AxolotlInputConfig( "description": "Whether to use flash-attention rms norm implementation - advanced use only" }, ) + flash_attn_fuse_qkv: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to fuse QKV projection into a single operation" + }, + ) flash_attn_fuse_mlp: bool | None = Field( default=None, json_schema_extra={ From d3fcc44417bceabd58bafd23214a0e77f9b1f4e3 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Wed, 27 Aug 2025 00:00:52 +0530 Subject: [PATCH 13/13] merge conflict check --- src/axolotl/utils/schemas/validation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 9b88ef852b..1cb0a076f8 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -574,9 +574,7 @@ def check_lora_axolotl_unsloth(cls, data): @model_validator(mode="after") def check_fused_lora(self): - if self.adapter in ["lora", "qlora", "qalora"] and ( - self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp - ): + if self.adapter in ["lora", "qlora", "qalora"] and self.flash_attn_fuse_mlp: raise ValueError("Fused modules are not supported with LoRA/QLoRA/QALoRA") return self
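
Taken together, the series routes the new "qalora" adapter through the existing LoRA path: load_lora() sets use_qalora=True and forwards qalora_group_size into the PEFT config kwargs, while the schema validators require load_in_4bit, an explicit group size, and (for now) no merge_lora. The snippet below is a minimal sketch of that intended PEFT-side contract, written outside axolotl. It assumes an installed PEFT build whose LoraConfig accepts the use_qalora and qalora_group_size keyword arguments these patches forward; the base model and hyperparameters are placeholders mirroring examples/llama-2/qalora.yml, not anything defined by the PR itself.

    # Minimal sketch (not part of this PR): the PEFT-side contract that
    # load_lora() relies on when cfg.adapter == "qalora". Assumes a PEFT
    # version whose LoraConfig exposes use_qalora / qalora_group_size.
    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # qalora requires 4-bit loading per the validators added in this series;
    # the remaining bitsandbytes settings are illustrative placeholders.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # base_model from examples/llama-2/qalora.yml
        quantization_config=bnb_config,
    )

    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        use_qalora=True,        # set unconditionally by load_lora() for adapter == "qalora"
        qalora_group_size=16,   # required; the validators reject a missing group size
    )
    model = get_peft_model(model, lora_config)

With a config like examples/llama-2/qalora.yml, axolotl builds the same kwargs internally, so the sketch only makes the assumed PEFT interface explicit.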