diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index bcde4bf96e..54142365b6 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -19,6 +19,7 @@
 )
 from transformers import PreTrainedModel
 
+from axolotl.loaders.adapters.builders.factory import AdapterBuilderFactory
 from axolotl.loaders.utils import get_linear_embedding_layers
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
@@ -177,19 +178,43 @@ def load_adapter(
     cfg: DictDefault,
     adapter: str | None,
     inference: bool = False,
-) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
-    if adapter is None:
-        return model, None
-    if hasattr(model, "enable_input_require_grads"):
-        model.enable_input_require_grads()
-    if adapter in ["lora", "qlora"]:
-        peft_model, lora_config = load_lora(model, cfg, inference=inference)
-        return peft_model, lora_config
-    if adapter == "llama-adapter":
-        peft_model, lora_config = load_llama_adapter(model, cfg)
-        return peft_model, lora_config
-
-    raise NotImplementedError(f"{adapter} PEFT adapter not available")
+    config_only: bool = False,
+) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
+    try:
+        if adapter is None:
+            return model, None
+        builder = AdapterBuilderFactory.create_builder(adapter, cfg)
+
+        config = builder.build_config(model)
+
+        if config_only:
+            return None, config
+
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+
+        model = builder.build_model(model, config, inference=inference)
+        return model, config
+
+    except ValueError as e:
+        LOG.debug(
+            f"Builder pattern failed, falling back to legacy adapter loading: {e}"
+        )
+
+        if adapter is None:
+            return model, None
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        if adapter in ["lora", "qlora"]:
+            peft_model, lora_config = load_lora(
+                model, cfg, inference=inference, config_only=config_only
+            )
+            return peft_model, lora_config
+        if adapter == "llama-adapter":
+            peft_model, lora_config = load_llama_adapter(model, cfg)
+            return peft_model, lora_config
+
+        raise NotImplementedError(f"{adapter} PEFT adapter not available") from None
 
 
 def load_llama_adapter(
diff --git a/src/axolotl/loaders/adapters/__init__.py b/src/axolotl/loaders/adapters/__init__.py
index e69de29bb2..71315c3ef0 100644
--- a/src/axolotl/loaders/adapters/__init__.py
+++ b/src/axolotl/loaders/adapters/__init__.py
@@ -0,0 +1,13 @@
+"""Adapters package."""
+
+from .builders import (
+    AdapterBuilderFactory,
+    BaseAdapterBuilder,
+    LoraAdapterBuilder,
+)
+
+__all__ = [
+    "AdapterBuilderFactory",
+    "BaseAdapterBuilder",
+    "LoraAdapterBuilder",
+]
diff --git a/src/axolotl/loaders/adapters/builders/__init__.py b/src/axolotl/loaders/adapters/builders/__init__.py
new file mode 100644
index 0000000000..588af11d8f
--- /dev/null
+++ b/src/axolotl/loaders/adapters/builders/__init__.py
@@ -0,0 +1,11 @@
+"""Adapter builders package."""
+
+from .base import BaseAdapterBuilder
+from .factory import AdapterBuilderFactory
+from .lora import LoraAdapterBuilder
+
+__all__ = [
+    "BaseAdapterBuilder",
+    "AdapterBuilderFactory",
+    "LoraAdapterBuilder",
+]
diff --git a/src/axolotl/loaders/adapters/builders/base.py b/src/axolotl/loaders/adapters/builders/base.py
new file mode 100644
index 0000000000..fa6e6a55c5
--- /dev/null
+++ b/src/axolotl/loaders/adapters/builders/base.py
@@ -0,0 +1,252 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
+from transformers import PreTrainedModel
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+class BaseAdapterBuilder(ABC):
+    """Base class for adapter builders"""
+
+    def __init__(self, cfg: DictDefault):
+        self.cfg = cfg
+        self.rank = int(os.environ.get("LOCAL_RANK", 0))
+
+    @abstractmethod
+    def build_config(self, model: PreTrainedModel, **kwargs) -> PeftConfig:
+        """Build the PEFT configuration"""
+        target_modules = self.prepare_target_modules(model)
+        target_parameters = self.prepare_target_parameters()
+
+        config_kwargs = self.build_common_config_kwargs()
+        config_kwargs.update(kwargs)
+
+        lora_config = LoraConfig(
+            r=self.cfg.lora_r,
+            lora_alpha=self.cfg.lora_alpha,
+            target_modules=target_modules,
+            target_parameters=target_parameters,
+            **config_kwargs,
+        )
+        return lora_config
+
+    @abstractmethod
+    def build_model(
+        self, model: PreTrainedModel, config: PeftConfig, *, inference: bool = False
+    ) -> PeftModel:
+        """Build the PEFT model"""
+        self.setup_quantization_for_training(model)
+
+        if self.cfg.lora_model_dir:
+            model = self.load_pretrained_adapter(model, inference)
+        else:
+            model = self.create_peft_model(model, config)
+
+        self.print_trainable_parameters(model)
+        self.setup_quantization_for_training_post_build(model)
+
+        return model
+
+    def prepare_target_modules(
+        self,
+        model: PreTrainedModel,
+        target_modules: Optional[Union[str, List[str]]] = None,
+    ) -> List[str]:
+        """
+        Prepare and validate target modules for the adapter.
+
+        Args:
+            model: The base model
+            target_modules: User-specified target modules
+
+        Returns:
+            List[str]: Processed list of target modules
+        """
+
+        lora_target_modules: Union[str, List[str]] = (
+            target_modules or self.cfg.lora_target_modules or []
+        )
+
+        if self.cfg.lora_target_linear:
+            from axolotl.loaders.adapter import find_all_linear_names
+
+            linear_names = find_all_linear_names(model)
+            LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
+            lora_target_modules_as_list = (
+                lora_target_modules
+                if isinstance(lora_target_modules, list)
+                else [lora_target_modules]
+                if lora_target_modules
+                else []
+            )
+            lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
+        elif isinstance(lora_target_modules, str):
+            lora_target_modules = [lora_target_modules]
+        elif lora_target_modules is None:
+            lora_target_modules = []
+
+        return lora_target_modules
+
+    def prepare_target_parameters(
+        self, target_parameters: Optional[Union[str, List[str]]] = None
+    ) -> List[str]:
+        """
+        Prepare target parameters for the adapter.
+
+        Args:
+            target_parameters: User-specified target parameters
+
+        Returns:
+            List[str]: Processed list of target parameters
+        """
+        result = target_parameters or self.cfg.lora_target_parameters or []
+        if isinstance(result, str):
+            return [result]
+        elif isinstance(result, list):
+            return result
+        else:
+            return []
+
+    def build_common_config_kwargs(self) -> Dict[str, Any]:
+        """
+        Build common configuration kwargs shared across adapter types.
+
+        Returns:
+            Dict[str, Any]: Common configuration parameters
+        """
+        config_kwargs = {}
+
+        # LoftQ configuration
+        loftq_bits = (
+            self.cfg.peft
+            and self.cfg.peft.loftq_config
+            and self.cfg.peft.loftq_config.loftq_bits
+        )
+        if loftq_bits:
+            from peft import LoftQConfig
+
+            config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
+            config_kwargs["init_lora_weights"] = "loftq"
+
+        # LoRA weight initialization
+        if self.cfg.peft_init_lora_weights:
+            config_kwargs["init_lora_weights"] = self.cfg.peft_init_lora_weights
+
+        # DoRA configuration
+        if self.cfg.peft_use_dora:
+            config_kwargs["use_dora"] = self.cfg.peft_use_dora
+            LOG.info("Initializing LoRA weights using DoRA. This might take longer.")
+
+        # RSLoRA configuration
+        if self.cfg.peft_use_rslora:
+            config_kwargs["use_rslora"] = self.cfg.peft_use_rslora
+
+        # Layer replication
+        if self.cfg.peft_layer_replication:
+            config_kwargs["layer_replication"] = self.cfg.peft_layer_replication
+
+        return config_kwargs
+
+    def setup_quantization_for_training(self, model: Union[PreTrainedModel, PeftModel]):
+        """
+        Setup quantization metadata for training.
+
+        Args:
+            model: The model to setup quantization for
+        """
+        from axolotl.loaders.adapter import setup_quantized_meta_for_peft
+
+        if (
+            self.cfg.fsdp_config
+            and self.cfg.adapter
+            and self.cfg.fsdp_config.cpu_ram_efficient_loading
+            and self.rank != 0
+        ):
+            setup_quantized_meta_for_peft(model)
+
+    def setup_quantization_for_training_post_build(
+        self, model: Union[PreTrainedModel, PeftModel]
+    ):
+        """
+        Setup quantization metadata after model building for training.
+
+        Args:
+            model: The model to setup quantization for
+        """
+        from axolotl.loaders.adapter import setup_quantized_peft_meta_for_training
+
+        if (
+            self.cfg.fsdp_config
+            and self.cfg.adapter
+            and self.cfg.fsdp_config.cpu_ram_efficient_loading
+            and self.rank != 0
+        ):
+            setup_quantized_peft_meta_for_training(model)
+
+    def load_pretrained_adapter(
+        self, model: PreTrainedModel, inference: bool = False
+    ) -> Union[PreTrainedModel, PeftModel]:
+        """
+        Load a pretrained adapter from a directory.
+
+        Args:
+            model: Base model to load adapter onto
+            inference: Whether this is for inference mode
+
+        Returns:
+            PeftModel: Model with loaded adapter
+        """
+
+        if not self.cfg.lora_model_dir:
+            return model
+
+        LOG.debug(f"Loading pretrained PEFT - {self.__class__.__name__}")
+        model_kwargs: Dict[str, Any] = {}
+
+        if self.cfg.lora_on_cpu:
+            model_kwargs["max_memory"] = {"cpu": "256GiB"}
+            model_kwargs["device_map"] = {"": "cpu"}
+
+        return PeftModel.from_pretrained(
+            model,
+            self.cfg.lora_model_dir,
+            is_trainable=(not inference),
+            **model_kwargs,
+        )
+
+    def create_peft_model(
+        self, model: PreTrainedModel, config: PeftConfig
+    ) -> PeftModel:
+        """
+        Create a PEFT model from base model and config.
+
+        Args:
+            model: Base model
+            config: PEFT configuration
+
+        Returns:
+            PeftModel: Created PEFT model
+        """
+        return get_peft_model(model, config)
+
+    def print_trainable_parameters(self, model: Union[PreTrainedModel, PeftModel]):
+        """
+        Print the number of trainable parameters in the model.
+
+        Args:
+            model: The model to analyze
+        """
+        if self.rank == 0:
+            try:
+                model.print_trainable_parameters()
+            except AttributeError as exc:
+                LOG.warning(
+                    "Exception caught during model.print_trainable_parameters(): %s",
+                    exc,
+                )
diff --git a/src/axolotl/loaders/adapters/builders/factory.py b/src/axolotl/loaders/adapters/builders/factory.py
new file mode 100644
index 0000000000..da15a4b235
--- /dev/null
+++ b/src/axolotl/loaders/adapters/builders/factory.py
@@ -0,0 +1,70 @@
+from typing import Dict, Type
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+from .base import BaseAdapterBuilder
+from .lora import LoraAdapterBuilder
+
+LOG = get_logger(__name__)
+
+
+class AdapterBuilderFactory:
+    """Factory for creating adapter builders based on adapter type."""
+
+    _builders: Dict[str, Type[BaseAdapterBuilder]] = {
+        "lora": LoraAdapterBuilder,
+        "qlora": LoraAdapterBuilder,
+    }
+
+    @classmethod
+    def register_builder(
+        cls, adapter_type: str, builder_class: Type[BaseAdapterBuilder]
+    ):
+        """
+        Register a new adapter builder.
+
+        Args:
+            adapter_type: Type of adapter (e.g., 'lora', 'qlora')
+            builder_class: Builder class that extends BaseAdapterBuilder
+        """
+        cls._builders[adapter_type] = builder_class
+        LOG.info(
+            f"Registered adapter builder for '{adapter_type}': {builder_class.__name__}"
+        )
+
+    @classmethod
+    def create_builder(cls, adapter_type: str, cfg: DictDefault) -> BaseAdapterBuilder:
+        """
+        Create an adapter builder for the specified type.
+
+        Args:
+            adapter_type: Type of adapter to create builder for
+            cfg: Configuration object
+
+        Returns:
+            BaseAdapterBuilder: Configured adapter builder
+
+        Raises:
+            ValueError: If adapter type is not supported
+        """
+        if adapter_type not in cls._builders:
+            available_types = list(cls._builders.keys())
+            raise ValueError(
+                f"Unsupported adapter type: {adapter_type}. "
+                f"Available types: {available_types}"
+            )
+
+        builder_class = cls._builders[adapter_type]
+        LOG.info(f"Creating {builder_class.__name__} for adapter type '{adapter_type}'")
+        return builder_class(cfg)
+
+    @classmethod
+    def get_supported_adapters(cls) -> list[str]:
+        """
+        Get list of supported adapter types.
+
+        Returns:
+            list[str]: List of supported adapter type names
+        """
+        return list(cls._builders.keys())
diff --git a/src/axolotl/loaders/adapters/builders/lora.py b/src/axolotl/loaders/adapters/builders/lora.py
new file mode 100644
index 0000000000..25749dfe90
--- /dev/null
+++ b/src/axolotl/loaders/adapters/builders/lora.py
@@ -0,0 +1,74 @@
+from peft import LoraConfig, PeftModel
+from transformers import PreTrainedModel
+
+from axolotl.utils.logging import get_logger
+
+from .base import BaseAdapterBuilder
+
+LOG = get_logger(__name__)
+
+
+class LoraAdapterBuilder(BaseAdapterBuilder):
+    """Builder for LoRA adapters."""
+
+    def build_config(self, model: PreTrainedModel, **kwargs) -> LoraConfig:
+        """
+        Build LoRA configuration.
+
+        Args:
+            model: The base model
+            **kwargs: Additional configuration options
+
+        Returns:
+            LoraConfig: Configured LoRA adapter
+        """
+        target_modules = self.prepare_target_modules(model)
+        target_parameters = self.prepare_target_parameters()
+
+        config_kwargs = self.build_common_config_kwargs()
+
+        config_kwargs.update(kwargs)
+
+        lora_config = LoraConfig(
+            r=self.cfg.lora_r,
+            lora_alpha=self.cfg.lora_alpha,
+            target_modules=target_modules,
+            target_parameters=target_parameters,
+            layers_to_transform=self.cfg.peft_layers_to_transform,
+            layers_pattern=self.cfg.peft_layers_pattern,
+            lora_dropout=self.cfg.lora_dropout,
+            fan_in_fan_out=self.cfg.lora_fan_in_fan_out,
+            modules_to_save=self.cfg.lora_modules_to_save
+            if self.cfg.lora_modules_to_save
+            else None,
+            bias="none",
+            task_type="CAUSAL_LM",
+            **config_kwargs,
+        )
+        return lora_config
+
+    def build_model(
+        self, model: PreTrainedModel, config: LoraConfig, *, inference: bool = False
+    ) -> PeftModel:
+        """
+        Build LoRA model.
+
+        Args:
+            model: Base model
+            config: LoRA configuration
+
+        Returns:
+            PeftModel: Model with LoRA adapter applied
+        """
+        self.setup_quantization_for_training(model)
+
+        if self.cfg.lora_model_dir:
+            LOG.debug("Loading pretrained PEFT - LoRA")
+            model = self.load_pretrained_adapter(model, inference)
+        else:
+            model = self.create_peft_model(model, config)
+
+        self.print_trainable_parameters(model)
+        self.setup_quantization_for_training_post_build(model)
+
+        return model
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index f438d6b61a..6ad3203d64 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -41,6 +41,7 @@
 from axolotl.common.architectures import MOE_ARCH_BLOCK
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.adapter import load_adapter, load_lora
+from axolotl.loaders.adapters.builders.factory import AdapterBuilderFactory
 from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
 from axolotl.loaders.patch_manager import PatchManager
 from axolotl.loaders.utils import (
@@ -375,12 +376,32 @@ def _load_adapters(self) -> PeftConfig | None:
             and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
             and not self.cfg.merge_lora
         ):
-            _, lora_config = load_lora(
-                self.model, self.cfg, inference=False, config_only=True
-            )
+            try:
+                builder = AdapterBuilderFactory.create_builder(
+                    self.cfg.adapter, self.cfg
+                )
+                lora_config = builder.build_config(self.model)
+            except (
+                ValueError,
+                ImportError,
+                TypeError,
+                AttributeError,
+                RuntimeError,
+            ) as e:
+                if self.cfg.adapter in ["lora", "qlora"]:
+                    # Fallback to legacy method
+                    LOG.debug(
+                        f"Builder pattern failed for config-only, falling back to legacy: {e}"
+                    )
+                    _, lora_config = load_lora(
+                        self.model, self.cfg, inference=False, config_only=True
+                    )
+                else:
+                    # Re-raise the original exception for non-LoRA adapters
+                    raise
        else:
             self.model, lora_config = load_adapter(
-                self.model, self.cfg, self.cfg.adapter
+                self.model, self.cfg, self.cfg.adapter, inference=self.inference
             )
 
         return lora_config
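A minimal usage sketch of the builder API introduced by this patch (not itself part of the diff). The factory, the builder methods, DictDefault, and the config keys (lora_r, lora_alpha, lora_dropout, lora_target_linear) all come from the files above; the "gpt2" checkpoint and the hyperparameter values are hypothetical stand-ins, and it assumes a peft release that accepts the arguments the builder forwards (e.g. target_parameters).

# Sketch only: exercising AdapterBuilderFactory / LoraAdapterBuilder directly.
from transformers import AutoModelForCausalLM

from axolotl.loaders.adapters import AdapterBuilderFactory
from axolotl.utils.dict import DictDefault

# Hypothetical example config values; any DictDefault with LoRA settings works.
cfg = DictDefault(
    {
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
    }
)

# Any causal LM serves as a stand-in base model here.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# The factory maps the adapter type to a builder ("lora"/"qlora" -> LoraAdapterBuilder).
builder = AdapterBuilderFactory.create_builder(cfg.adapter, cfg)

# build_config() produces the LoraConfig; build_model() wraps the base model with PEFT.
peft_config = builder.build_config(model)
peft_model = builder.build_model(model, peft_config, inference=False)

Because the factory keeps a registry, additional adapter types can be plugged in via AdapterBuilderFactory.register_builder("my-adapter", MyAdapterBuilder) (hypothetical names) without modifying load_adapter, which now only falls back to the legacy path when the factory raises ValueError for an unknown type.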