diff --git a/examples/llama-2/qalora.yml b/examples/llama-2/qalora.yml
new file mode 100644
index 0000000000..ceca01b825
--- /dev/null
+++ b/examples/llama-2/qalora.yml
@@ -0,0 +1,48 @@
+base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/qalora-out
+
+# This will be the new adapter type once implemented
+adapter: qalora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+eval_sample_packing: false
+
+qalora_group_size: 16
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_32bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+logging_steps: 1
+flash_attention: false
+
+warmup_ratio: 0.1
+evals_per_epoch: 4
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 867e6901cb..2302319afa 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -89,6 +89,15 @@ def load_lora(
     if loftq_bits:
         lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
         lora_config_kwargs["init_lora_weights"] = "loftq"
+
+    if cfg.adapter == "qalora":
+        lora_config_kwargs["use_qalora"] = True
+
+        if hasattr(cfg, "qalora_group_size") and cfg.qalora_group_size:
+            lora_config_kwargs["qalora_group_size"] = cfg.qalora_group_size
+        else:
+            raise ValueError("qalora_group_size must be set when using qalora")
+
     if cfg.peft_init_lora_weights:
         lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
     if cfg.peft_use_dora:
@@ -170,7 +179,7 @@ def load_adapter(
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
-    if adapter in ["lora", "qlora"]:
+    if adapter in ["lora", "qlora", "qalora"]:
         peft_model, lora_config = load_lora(model, cfg, inference=inference)
         return peft_model, lora_config
     if adapter == "llama-adapter":
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index a9507d685d..1185c02bb7 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -155,8 +155,8 @@ def is_fsdp_enabled(self):
 
     @property
     def is_qlora_and_fsdp_enabled(self):
-        """Property that determines if FSDP with QLoRA is enabled."""
-        return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
+        """Property that determines if FSDP with QLoRA/QALoRA is enabled."""
+        return self.is_fsdp_enabled and self.cfg.adapter in ["qlora", "qalora"]
 
     def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
         """Load and prepare the model with all configurations and patches.
@@ -318,7 +318,7 @@ def _configure_embedding_dtypes(self):
 
         # Apply gradient checkpointing if needed
         needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
-        if self.cfg.adapter in ["lora", "qlora"]:
+        if self.cfg.adapter in ["lora", "qlora", "qalora"]:
             needs_fa2_dtype = True
         if self.cfg.gradient_checkpointing:
             self.model.gradient_checkpointing_enable(
@@ -533,7 +533,7 @@ def _set_quantization_config(self):
                     **self.model_config.quantization_config
                 )
         if (
-            self.cfg.adapter in ["qlora", "lora"]
+            self.cfg.adapter in ["qlora", "lora", "qalora"]
             and hasattr(self.model_config, "quantization_config")
             and self.model_config.quantization_config["quant_method"]
             in ["gptq", "awq", "bitsandbytes"]
@@ -552,6 +552,14 @@
                 self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                     **self.model_config.quantization_config
                 )
+        elif (
+            self.cfg.adapter in ["qlora", "qalora"]
+            and self.model_kwargs["load_in_4bit"]
+        ):
+            quantization_config = getattr(self.model_config, "quantization_config", {})
+            self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
+                **quantization_config
+            )
         elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
             "load_in_4bit", False
         ):
@@ -859,7 +867,7 @@ def _prepare_model_for_quantization(self):
 
         if (
             not skip_prepare_model_for_kbit_training
-            and self.cfg.adapter in ["lora", "qlora"]
+            and self.cfg.adapter in ["lora", "qlora", "qalora"]
             and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
         ):
             LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 4d660d4b75..1b55a4e712 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -541,6 +541,12 @@ class AxolotlInputConfig(
             "description": "Whether to use flash-attention rms norm implementation - advanced use only"
         },
     )
+    flash_attn_fuse_qkv: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to fuse QKV projection into a single operation"
+        },
+    )
     flash_attn_fuse_mlp: bool | None = Field(
         default=None,
         json_schema_extra={
@@ -1073,7 +1079,7 @@ def check_auto_enable_lora_kernels(cls, data):
         if data.get("rl"):
             # RL trainers not tested so don't enable kernels by default
             return data
-        if data.get("adapter") in ["lora", "qlora"]:
+        if data.get("adapter") in ["lora", "qlora", "qalora"]:
             # Skip if already set, using unsloth optimizations, or using 8-bit
             unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
             kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 8f4718aa96..03debb033f 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -19,6 +19,14 @@ class TorchIntDType(Enum):
     int8 = getattr(torch, "int8", None)
 
 
+class AdapterEnum(str, Enum):
+    """Adapter type configuration subset"""
+
+    lora = "lora"
+    qlora = "qlora"
+    qalora = "qalora"
+
+
 class RLType(str, Enum):
     """RL trainer type configuration subset"""
 
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index de29521cb4..cf36ee82dc 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -4,6 +4,8 @@
 
 from pydantic import BaseModel, Field, field_validator, model_validator
 
+from axolotl.utils.schemas.enums import AdapterEnum
+
 
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
@@ -38,10 +40,10 @@ class LoraConfig(BaseModel):
         default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
     )
 
-    adapter: str | None = Field(
+    adapter: AdapterEnum | None = Field(
         default=None,
         json_schema_extra={
-            "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
+            "description": "If you want to use 'lora' or 'qlora' or 'qalora' or leave blank to train all parameters in original model"
         },
     )
     lora_model_dir: str | None = Field(
@@ -128,6 +130,10 @@ class LoraConfig(BaseModel):
             "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
         },
     )
+    qalora_group_size: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "Group size for QALoRA quantization pooling"},
+    )
 
     merge_lora: bool | None = None
 
@@ -168,6 +174,28 @@ def validate_qlora(self):
 
             if not self.load_in_4bit:
                 raise ValueError("Require cfg.load_in_4bit to be True for qlora")
+
+        if self.adapter == "qalora":
+            if self.merge_lora:
+                # can't merge qalora if loaded in 8bit or 4bit
+                if self.load_in_8bit:
+                    raise ValueError("Can't merge qalora if loaded in 8bit")
+
+                if self.gptq:
+                    raise ValueError("Can't merge qalora if gptq")
+
+                if self.load_in_4bit:
+                    raise ValueError("Can't merge qalora if loaded in 4bit")
+
+            else:
+                if self.load_in_8bit:
+                    raise ValueError("Can't load qalora in 8bit")
+
+                if self.gptq:
+                    raise ValueError("Can't load qalora if gptq")
+
+                if not self.load_in_4bit:
+                    raise ValueError("Require cfg.load_in_4bit to be True for qalora")
         return self
 
     @field_validator("loraplus_lr_embedding")
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 791894990c..1cb0a076f8 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -574,6 +574,19 @@ def check_lora_axolotl_unsloth(cls, data):
 
     @model_validator(mode="after")
     def check_fused_lora(self):
+        if self.adapter in ["lora", "qlora", "qalora"] and self.flash_attn_fuse_mlp:
+            raise ValueError("Fused modules are not supported with LoRA/QLoRA/QALoRA")
+        return self
+
+    @model_validator(mode="after")
+    def validate_qalora(self):
+        if self.adapter == "qalora":
+            if not self.load_in_4bit:
+                raise ValueError("QALoRA requires load_in_4bit to be True")
+            if not hasattr(self, "qalora_group_size") or self.qalora_group_size is None:
+                raise ValueError("QALoRA requires qalora_group_size to be specified")
+            if self.merge_lora:
+                raise ValueError("QALoRA does not support merge_lora yet")
         if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
             raise ValueError("Fused modules are not supported with LoRA/QLoRA")
         return self
@@ -582,7 +595,7 @@ def check_fused_lora(self):
     @classmethod
     def warn_qlora_zero3_w_use_reentrant(cls, data):
         if (
-            data.get("adapter") == "qlora"
+            data.get("adapter") in ["qlora", "qalora"]
            and data.get("gradient_checkpointing_kwargs", {})
            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
            is False
@@ -688,7 +701,7 @@ def check_rl_config_gradient_checkpointing(cls, data):
             and data.get("gradient_checkpointing_kwargs")
             and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
             and data.get("load_in_4bit")
-            and data.get("adapter") == "qlora"
+            and data.get("adapter") in ["qlora", "qalora"]
             and data.get("capabilities")
             and data.get("capabilities").get("n_gpu", 1) > 1
         ):
@@ -1186,8 +1199,10 @@ def check_relora(self):
         if self.relora:
             if not self.jagged_restart_steps:
                 raise ValueError("jagged_restart_steps must be set to use ReLoRA")
-            if self.adapter not in ("lora", "qlora"):
"qlora"): - raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") + if self.adapter not in ("lora", "qlora", "qalora"): + raise ValueError( + "cfg.adapter must be lora, qlora, or qalora to use ReLoRA" + ) if self.fsdp or self.fsdp_config: raise ValueError("fsdp not supported with ReLoRA") diff --git a/tests/test_loaders.py b/tests/test_loaders.py index f516d0ca4e..36d025491b 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -57,7 +57,7 @@ def test_set_device_map_config(self): # check torch_dtype assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"] - @pytest.mark.parametrize("adapter", ["lora", "qlora", None]) + @pytest.mark.parametrize("adapter", ["lora", "qlora", "qalora", None]) @pytest.mark.parametrize("load_in_8bit", [True, False]) @pytest.mark.parametrize("load_in_4bit", [True, False]) @pytest.mark.parametrize("gptq", [True, False])