51 changes: 38 additions & 13 deletions src/axolotl/loaders/adapter.py
@@ -19,6 +19,7 @@
 )
 from transformers import PreTrainedModel
 
+from axolotl.loaders.adapters.builders.factory import AdapterBuilderFactory
 from axolotl.loaders.utils import get_linear_embedding_layers
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
@@ -177,19 +178,43 @@ def load_adapter(
     cfg: DictDefault,
     adapter: str | None,
     inference: bool = False,
-) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
-    if adapter is None:
-        return model, None
-    if hasattr(model, "enable_input_require_grads"):
-        model.enable_input_require_grads()
-    if adapter in ["lora", "qlora"]:
-        peft_model, lora_config = load_lora(model, cfg, inference=inference)
-        return peft_model, lora_config
-    if adapter == "llama-adapter":
-        peft_model, lora_config = load_llama_adapter(model, cfg)
-        return peft_model, lora_config
-
-    raise NotImplementedError(f"{adapter} PEFT adapter not available")
+    config_only: bool = False,
+) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
+    try:
+        if adapter is None:
+            return model, None
+        builder = AdapterBuilderFactory.create_builder(adapter, cfg)
+
+        config = builder.build_config(model)
+
+        if config_only:
+            return None, config
+
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+
+        model = builder.build_model(model, config, inference=inference)
+        return model, config
+
+    except ValueError as e:
+        LOG.debug(
+            f"Builder pattern failed, falling back to legacy adapter loading: {e}"
+        )
+
+    if adapter is None:
+        return model, None
+    if hasattr(model, "enable_input_require_grads"):
+        model.enable_input_require_grads()
+    if adapter in ["lora", "qlora"]:
+        peft_model, lora_config = load_lora(
+            model, cfg, inference=inference, config_only=config_only
+        )
+        return peft_model, lora_config
+    if adapter == "llama-adapter":
+        peft_model, lora_config = load_llama_adapter(model, cfg)
+        return peft_model, lora_config
+
+    raise NotImplementedError(f"{adapter} PEFT adapter not available") from None
 
 
 def load_llama_adapter(
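
For orientation, here is a minimal sketch (not part of the PR) of how the reworked entry point might be exercised. It assumes the first positional argument of load_adapter is the base model (that parameter sits above the hunk shown here); the checkpoint name and LoRA settings in cfg are placeholder values.

from transformers import AutoModelForCausalLM

from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault

# Placeholder config; in real runs these values come from an axolotl YAML config.
cfg = DictDefault(
    {
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_target_linear": True,
    }
)

# Example checkpoint; any causal LM is handled the same way here.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

# New config_only path: build just the PeftConfig without wrapping the model.
_, lora_config = load_adapter(model, cfg, cfg.adapter, config_only=True)

# Existing path: wrap the model (builder first, legacy loaders as the fallback).
peft_model, peft_config = load_adapter(model, cfg, cfg.adapter)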
13 changes: 13 additions & 0 deletions src/axolotl/loaders/adapters/__init__.py
@@ -0,0 +1,13 @@
"""Adapters package."""

from .builders import (
AdapterBuilderFactory,
BaseAdapterBuilder,
LoraAdapterBuilder,
)

__all__ = [
"AdapterBuilderFactory",
"BaseAdapterBuilder",
"LoraAdapterBuilder",
]
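A note on the re-exports above: downstream code can import the builder machinery from the package root or from the defining module; a tiny check (assuming the layout introduced in this PR):

# Both import paths resolve to the same class via the package-level re-export.
from axolotl.loaders.adapters import AdapterBuilderFactory
from axolotl.loaders.adapters.builders.factory import (
    AdapterBuilderFactory as FactoryDirect,
)

assert AdapterBuilderFactory is FactoryDirect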
11 changes: 11 additions & 0 deletions src/axolotl/loaders/adapters/builders/__init__.py
@@ -0,0 +1,11 @@
"""Adapter builders package."""

from .base import BaseAdapterBuilder
from .factory import AdapterBuilderFactory
from .lora import LoraAdapterBuilder

__all__ = [
"BaseAdapterBuilder",
"AdapterBuilderFactory",
"LoraAdapterBuilder",
]
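The factory module itself (builders/factory.py) is not included in this excerpt. A registry-style implementation would be consistent with the call sites above: AdapterBuilderFactory.create_builder(adapter, cfg) returns a builder instance and raises ValueError for unknown adapter names, which load_adapter catches to fall back to the legacy path. The sketch below illustrates that shape; the registry contents and method body are assumptions, not the PR's actual code.

# Illustrative sketch only; the real factory.py is not shown in this diff.
from axolotl.loaders.adapters.builders.base import BaseAdapterBuilder
from axolotl.loaders.adapters.builders.lora import LoraAdapterBuilder
from axolotl.utils.dict import DictDefault


class AdapterBuilderFactory:
    """Maps adapter names (e.g. "lora", "qlora") to builder classes."""

    # Hypothetical registry; the real mapping may differ.
    _builders: dict[str, type[BaseAdapterBuilder]] = {
        "lora": LoraAdapterBuilder,
        "qlora": LoraAdapterBuilder,
    }

    @classmethod
    def create_builder(cls, adapter: str, cfg: DictDefault) -> BaseAdapterBuilder:
        if adapter not in cls._builders:
            # load_adapter() catches ValueError and falls back to the legacy loaders.
            raise ValueError(f"No builder registered for adapter: {adapter}")
        return cls._builders[adapter](cfg)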
252 changes: 252 additions & 0 deletions src/axolotl/loaders/adapters/builders/base.py
@@ -0,0 +1,252 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from transformers import PreTrainedModel

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class BaseAdapterBuilder(ABC):
"""Base class for adapter builders"""

def __init__(self, cfg: DictDefault):
self.cfg = cfg
self.rank = int(os.environ.get("LOCAL_RANK", 0))

@abstractmethod
def build_config(self, model: PreTrainedModel, **kwargs) -> PeftConfig:
"""Build the PEFT configuration"""
target_modules = self.prepare_target_modules(model)
target_parameters = self.prepare_target_parameters()

config_kwargs = self.build_common_config_kwargs()
config_kwargs.update(kwargs)

lora_config = LoraConfig(
r=self.cfg.lora_r,
lora_alpha=self.cfg.lora_alpha,
target_modules=target_modules,
target_parameters=target_parameters,
**config_kwargs,
)
return lora_config

@abstractmethod
def build_model(
self, model: PreTrainedModel, config: PeftConfig, *, inference: bool = False
) -> PeftModel:
"""Build the PEFT model"""
self.setup_quantization_for_training(model)

if self.cfg.lora_model_dir:
model = self.load_pretrained_adapter(model, inference)
else:
model = self.create_peft_model(model, config)

self.print_trainable_parameters(model)
self.setup_quantization_for_training_post_build(model)

return model

def prepare_target_modules(
self,
model: PreTrainedModel,
target_modules: Optional[Union[str, List[str]]] = None,
) -> List[str]:
"""
Prepare and validate target modules for the adapter.

Args:
model: The base model
target_modules: User-specified target modules

Returns:
List[str]: Processed list of target modules
"""

lora_target_modules: Union[str, List[str]] = (
target_modules or self.cfg.lora_target_modules or []
)

if self.cfg.lora_target_linear:
from axolotl.loaders.adapter import find_all_linear_names

linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
if lora_target_modules
else []
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
elif isinstance(lora_target_modules, str):
lora_target_modules = [lora_target_modules]
elif lora_target_modules is None:
lora_target_modules = []

return lora_target_modules

def prepare_target_parameters(
self, target_parameters: Optional[Union[str, List[str]]] = None
) -> List[str]:
"""
Prepare target parameters for the adapter.

Args:
target_parameters: User-specified target parameters

Returns:
List[str]: Processed list of target parameters
"""
result = target_parameters or self.cfg.lora_target_parameters or []
if isinstance(result, str):
return [result]
elif isinstance(result, list):
return result
else:
return []

def build_common_config_kwargs(self) -> Dict[str, Any]:
"""
Build common configuration kwargs shared across adapter types.

Returns:
Dict[str, Any]: Common configuration parameters
"""
config_kwargs = {}

# LoftQ configuration
loftq_bits = (
self.cfg.peft
and self.cfg.peft.loftq_config
and self.cfg.peft.loftq_config.loftq_bits
)
if loftq_bits:
from peft import LoftQConfig

config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
config_kwargs["init_lora_weights"] = "loftq"

# LoRA weight initialization
if self.cfg.peft_init_lora_weights:
config_kwargs["init_lora_weights"] = self.cfg.peft_init_lora_weights

# DoRA configuration
if self.cfg.peft_use_dora:
config_kwargs["use_dora"] = self.cfg.peft_use_dora
LOG.info("Initializing LoRA weights using DoRA. This might take longer.")

# RSLoRA configuration
if self.cfg.peft_use_rslora:
config_kwargs["use_rslora"] = self.cfg.peft_use_rslora

# Layer replication
if self.cfg.peft_layer_replication:
config_kwargs["layer_replication"] = self.cfg.peft_layer_replication

return config_kwargs

def setup_quantization_for_training(self, model: Union[PreTrainedModel, PeftModel]):
"""
Setup quantization metadata for training.

Args:
model: The model to setup quantization for
"""
from axolotl.loaders.adapter import setup_quantized_meta_for_peft

if (
self.cfg.fsdp_config
and self.cfg.adapter
and self.cfg.fsdp_config.cpu_ram_efficient_loading
and self.rank != 0
):
setup_quantized_meta_for_peft(model)

def setup_quantization_for_training_post_build(
self, model: Union[PreTrainedModel, PeftModel]
):
"""
Setup quantization metadata after model building for training.

Args:
model: The model to setup quantization for
"""
from axolotl.loaders.adapter import setup_quantized_peft_meta_for_training

if (
self.cfg.fsdp_config
and self.cfg.adapter
and self.cfg.fsdp_config.cpu_ram_efficient_loading
and self.rank != 0
):
setup_quantized_peft_meta_for_training(model)

def load_pretrained_adapter(
self, model: PreTrainedModel, inference: bool = False
) -> Union[PreTrainedModel, PeftModel]:
"""
Load a pretrained adapter from a directory.

Args:
model: Base model to load adapter onto
inference: Whether this is for inference mode

Returns:
PeftModel: Model with loaded adapter
"""

if not self.cfg.lora_model_dir:
return model

LOG.debug(f"Loading pretrained PEFT - {self.__class__.__name__}")
model_kwargs: Dict[str, Any] = {}

if self.cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}

return PeftModel.from_pretrained(
model,
self.cfg.lora_model_dir,
is_trainable=(not inference),
**model_kwargs,
)

def create_peft_model(
self, model: PreTrainedModel, config: PeftConfig
) -> PeftModel:
"""
Create a PEFT model from base model and config.

Args:
model: Base model
config: PEFT configuration

Returns:
PeftModel: Created PEFT model
"""
return get_peft_model(model, config)

def print_trainable_parameters(self, model: Union[PreTrainedModel, PeftModel]):
"""
Print the number of trainable parameters in the model.

Args:
model: The model to analyze
"""
if self.rank == 0:
try:
model.print_trainable_parameters()
except AttributeError as exc:
LOG.warning(
"Exception caught during model.print_trainable_parameters(): %s",
exc,
)
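To make the BaseAdapterBuilder contract concrete, here is a minimal subclass sketch that defers to the shared logic above. The real LoraAdapterBuilder lives in builders/lora.py, which is not shown in this excerpt, so the class body below is an assumption; it mainly illustrates how subclass-specific kwargs flow through config_kwargs.update(kwargs) into LoraConfig.

from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel

from axolotl.loaders.adapters.builders.base import BaseAdapterBuilder


class MinimalLoraBuilder(BaseAdapterBuilder):
    """Illustrative builder that reuses BaseAdapterBuilder's LoRA defaults."""

    def build_config(self, model: PreTrainedModel, **kwargs) -> PeftConfig:
        # The base implementation resolves target modules/parameters and the
        # shared kwargs (LoftQ, DoRA, RSLoRA, layer replication); extra kwargs
        # such as lora_dropout are forwarded into LoraConfig.
        return super().build_config(
            model, lora_dropout=self.cfg.lora_dropout or 0.0, **kwargs
        )

    def build_model(
        self, model: PreTrainedModel, config: PeftConfig, *, inference: bool = False
    ) -> PeftModel:
        # The base implementation handles quantized-meta setup, loading from
        # cfg.lora_model_dir when set, and otherwise get_peft_model().
        return super().build_model(model, config, inference=inference)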