Changes from 4 commits
53 changes: 41 additions & 12 deletions src/axolotl/loaders/adapter.py
@@ -21,6 +21,7 @@
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.loaders.adapters.builders.factory import AdapterBuilderFactory

LOG = get_logger(__name__)

@@ -165,19 +166,47 @@ def load_adapter(
    cfg: DictDefault,
    adapter: str | None,
    inference: bool = False,
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
Contributor
⚠️ Potential issue

Return type must include None for config_only=True.

Function can return (None, config) but the annotation excludes None for the model.

Apply:

 def load_adapter(
@@
-    config_only: bool = False,
-) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
+    config_only: bool = False,
+) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
📝 Committable suggestion

Suggested change
-    config_only: bool = False,
-) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
+    config_only: bool = False,
+) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
🤖 Prompt for AI Agents
In src/axolotl/loaders/adapter.py around lines 169-170, the return type
annotation disallows returning None for the model though the function can return
(None, config) when config_only=True; update the annotation to allow None for
the model (e.g. tuple[PreTrainedModel | PeftModel | PeftMixedModel | None,
PeftConfig | None]) and audit callers to handle the possible None model
accordingly.
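
For context, a minimal caller-side sketch of handling the widened annotation (the exact load_adapter call shape and variable names are illustrative, not taken from this PR):

# Hypothetical caller: with config_only=True the model slot is expected to be None,
# so only the returned PeftConfig is used.
_, peft_config = load_adapter(model, cfg, cfg.adapter, config_only=True)
if peft_config is None:
    raise ValueError(f"no PEFT config produced for adapter {cfg.adapter!r}")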

    if adapter is None:
        return model, None
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    if adapter in ["lora", "qlora"]:
        peft_model, lora_config = load_lora(model, cfg, inference=inference)
        return peft_model, lora_config
    if adapter == "llama-adapter":
        peft_model, lora_config = load_llama_adapter(model, cfg)
        return peft_model, lora_config

    raise NotImplementedError(f"{adapter} PEFT adapter not available")
    try:
        if adapter is None:
            return model, None
        builder = AdapterBuilderFactory.create_builder(adapter, cfg)

        if not builder:
            LOG.warning(f"No builder found for adapter type '{adapter}'")
            return model, None

        config = builder.build_config(model)

        if config_only:
            return None, config

        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()

        model = builder.build_model(model, config)
        return model, config

Contributor
🛠️ Refactor suggestion

⚠️ Potential issue

Propagate inference to builder and remove dead check.

  • inference is not forwarded; pretrained adapters will be loaded as trainable during inference.
  • AdapterBuilderFactory.create_builder never returns None, so the "if not builder" branch is unreachable.

Apply:

-        builder = AdapterBuilderFactory.create_builder(adapter, cfg)
-
-        if not builder:
-            LOG.warning(f"No builder found for adapter type '{adapter}'")
-            return model, None
+        builder = AdapterBuilderFactory.create_builder(adapter, cfg)
@@
-        model = builder.build_model(model, config)
+        model = builder.build_model(model, config, inference=inference)
         return model, config

And update builder signatures accordingly (outside this file):

--- a/src/axolotl/loaders/adapters/builders/base.py
+++ b/src/axolotl/loaders/adapters/builders/base.py
@@
-    def build_model(self, model: PreTrainedModel, config: PeftConfig) -> PeftModel:
+    def build_model(
+        self, model: PreTrainedModel, config: PeftConfig, *, inference: bool = False
+    ) -> PeftModel:
@@
-        if self.cfg.lora_model_dir:
-            model = self.load_pretrained_adapter(model)
+        if self.cfg.lora_model_dir:
+            model = self.load_pretrained_adapter(model, inference=inference)
         else:
             model = self.create_peft_model(model, config)
--- a/src/axolotl/loaders/adapters/builders/lora.py
+++ b/src/axolotl/loaders/adapters/builders/lora.py
@@
-    def build_model(self, model: PreTrainedModel, config: LoraConfig) -> PeftModel:
+    def build_model(
+        self, model: PreTrainedModel, config: LoraConfig, *, inference: bool = False
+    ) -> PeftModel:
@@
-        if self.cfg.lora_model_dir:
+        if self.cfg.lora_model_dir:
             LOG.debug("Loading pretrained PEFT - LoRA")
-            model = self.load_pretrained_adapter(model)
+            model = self.load_pretrained_adapter(model, inference=inference)
         else:
             model = self.create_peft_model(model, config)
📝 Committable suggestion

Suggested change
-    try:
-        if adapter is None:
-            return model, None
-        builder = AdapterBuilderFactory.create_builder(adapter, cfg)
-        if not builder:
-            LOG.warning(f"No builder found for adapter type '{adapter}'")
-            return model, None
-        config = builder.build_config(model)
-        if config_only:
-            return None, config
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        model = builder.build_model(model, config)
-        return model, config
+    try:
+        if adapter is None:
+            return model, None
+        builder = AdapterBuilderFactory.create_builder(adapter, cfg)
+        config = builder.build_config(model)
+        if config_only:
+            return None, config
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        model = builder.build_model(model, config, inference=inference)
+        return model, config
🧰 Tools
🪛 Ruff (0.12.2)

189-189: Consider moving this statement to an else block

(TRY300)
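
For reference, a minimal self-contained sketch of the pattern the TRY300 hint points at: the success-path return moves into an else block so it only runs when the try body raised nothing (behavior is unchanged here because the return is the last statement of the try). Function and fallback names are illustrative, not code from this PR:

def _build_or_fall_back(builder, model, config):
    try:
        model = builder.build_model(model, config)
    except Exception as exc:
        # Any builder failure falls through to the legacy path below.
        LOG.debug("Builder pattern failed, falling back: %s", exc)
    else:
        return model, config
    return _legacy_load_adapter(model, config)  # hypothetical legacy fallback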

    except Exception as e:
        LOG.debug(
            f"Builder pattern failed, falling back to legacy adapter loading: {e}"
        )

        if adapter is None:
            return model, None
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        if adapter in ["lora", "qlora"]:
            peft_model, lora_config = load_lora(
                model, cfg, inference=inference, config_only=config_only
            )
            return peft_model, lora_config
        if adapter == "llama-adapter":
            peft_model, lora_config = load_llama_adapter(model, cfg)
            return peft_model, lora_config

        raise NotImplementedError(f"{adapter} PEFT adapter not available") from None


def load_llama_adapter(
13 changes: 13 additions & 0 deletions src/axolotl/loaders/adapters/__init__.py
@@ -0,0 +1,13 @@
"""Adapters package."""

from .builders import (
    AdapterBuilderFactory,
    BaseAdapterBuilder,
    LoraAdapterBuilder,
)

__all__ = [
    "AdapterBuilderFactory",
    "BaseAdapterBuilder",
    "LoraAdapterBuilder",
]
11 changes: 11 additions & 0 deletions src/axolotl/loaders/adapters/builders/__init__.py
@@ -0,0 +1,11 @@
"""Adapter builders package."""

from .base import BaseAdapterBuilder
from .factory import AdapterBuilderFactory
from .lora import LoraAdapterBuilder

__all__ = [
    "BaseAdapterBuilder",
    "AdapterBuilderFactory",
    "LoraAdapterBuilder",
]
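
A minimal usage sketch of the new package, relying only on the factory calls already exercised in adapter.py above (AdapterBuilderFactory.create_builder, build_config, build_model); cfg and model stand in for an existing DictDefault config and a loaded base model:

from axolotl.loaders.adapters import AdapterBuilderFactory

# Build a LoRA config for the base model, then wrap the model with it.
builder = AdapterBuilderFactory.create_builder("lora", cfg)
peft_config = builder.build_config(model)
peft_model = builder.build_model(model, peft_config)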
249 changes: 249 additions & 0 deletions src/axolotl/loaders/adapters/builders/base.py
@@ -0,0 +1,249 @@
import os
from abc import abstractmethod, ABC
from typing import Optional, Union, List, Any, Dict
from transformers import PreTrainedModel
from peft import PeftConfig, PeftModel, get_peft_model, LoraConfig

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class BaseAdapterBuilder(ABC):
"""Base class for adapter builders"""

def __init__(self, cfg: DictDefault):
self.cfg = cfg
self.rank = int(os.environ.get("LOCAL_RANK", 0))

@abstractmethod
def build_config(self, model: PreTrainedModel, **kwargs) -> PeftConfig:
"""Build the PEFT configuration"""
target_modules = self.prepare_target_modules(model)
target_parameters = self.prepare_target_parameters()

config_kwargs = self.build_common_config_kwargs()
config_kwargs.update(kwargs)

lora_config = LoraConfig(
r=self.cfg.lora_r,
lora_alpha=self.cfg.lora_alpha,
target_modules=target_modules,
target_parameters=target_parameters,
**config_kwargs,
)
return lora_config

@abstractmethod
def build_model(self, model: PreTrainedModel, config: PeftConfig) -> PeftModel:
"""Build the PEFT model"""
self.setup_quantization_for_training(model)

if self.cfg.lora_model_dir:
model = self.load_pretrained_adapter(model)
else:
model = self.create_peft_model(model, config)

self.print_trainable_parameters(model)
self.setup_quantization_for_training_post_build(model)

return model
Contributor
⚠️ Potential issue

Propagate inference through the base builder.

Ensures consistent behavior across all builders when loading pretrained adapters.

-    def build_model(self, model: PreTrainedModel, config: PeftConfig) -> PeftModel:
+    def build_model(
+        self, model: PreTrainedModel, config: PeftConfig, *, inference: bool = False
+    ) -> PeftModel:
         """Build the PEFT model"""
         self.setup_quantization_for_training(model)
 
         if self.cfg.lora_model_dir:
-            model = self.load_pretrained_adapter(model)
+            model = self.load_pretrained_adapter(model, inference=inference)
         else:
             model = self.create_peft_model(model, config)
📝 Committable suggestion

Suggested change
-    @abstractmethod
-    def build_model(self, model: PreTrainedModel, config: PeftConfig) -> PeftModel:
-        """Build the PEFT model"""
-        self.setup_quantization_for_training(model)
-        if self.cfg.lora_model_dir:
-            model = self.load_pretrained_adapter(model)
-        else:
-            model = self.create_peft_model(model, config)
-        self.print_trainable_parameters(model)
-        self.setup_quantization_for_training_post_build(model)
-        return model
+    @abstractmethod
+    def build_model(
+        self, model: PreTrainedModel, config: PeftConfig, *, inference: bool = False
+    ) -> PeftModel:
+        """Build the PEFT model"""
+        self.setup_quantization_for_training(model)
+        if self.cfg.lora_model_dir:
+            model = self.load_pretrained_adapter(model, inference=inference)
+        else:
+            model = self.create_peft_model(model, config)
+        self.print_trainable_parameters(model)
+        self.setup_quantization_for_training_post_build(model)
+        return model
🤖 Prompt for AI Agents
In src/axolotl/loaders/adapters/builders/base.py around lines 38-51, the base
build_model currently branches to load a pretrained adapter without propagating
the builder's inference/training mode or relevant flags; update build_model so
that when self.cfg.lora_model_dir is truthy you call load_pretrained_adapter in
a way that preserves the builder's inference mode and any quantization/training
flags (e.g., pass an inference/training argument or the builder's context), and
if load_pretrained_adapter signature needs it, add a parameter to accept and
forward that flag; ensure both code paths (load_pretrained_adapter and
create_peft_model) receive the same inference/training/quantization context,
then keep the existing setup_quantization_for_training_post_build and
print_trainable_parameters calls and return the model.


    def prepare_target_modules(
        self,
        model: PreTrainedModel,
        target_modules: Optional[Union[str, List[str]]] = None,
    ) -> List[str]:
        """
        Prepare and validate target modules for the adapter.
        Args:
            model: The base model
            target_modules: User-specified target modules
        Returns:
            List[str]: Processed list of target modules
        """

        lora_target_modules: Union[str, List[str]] = (
            target_modules or self.cfg.lora_target_modules or []
        )

        if self.cfg.lora_target_linear:
            from axolotl.loaders.adapter import find_all_linear_names

            linear_names = find_all_linear_names(model)
            LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
            lora_target_modules_as_list = (
                lora_target_modules
                if isinstance(lora_target_modules, list)
                else [lora_target_modules]
                if lora_target_modules
                else []
            )
            lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
        elif isinstance(lora_target_modules, str):
            lora_target_modules = [lora_target_modules]
        elif lora_target_modules is None:
            lora_target_modules = []

        return lora_target_modules

    def prepare_target_parameters(
        self, target_parameters: Optional[Union[str, List[str]]] = None
    ) -> List[str]:
        """
        Prepare target parameters for the adapter.
        Args:
            target_parameters: User-specified target parameters
        Returns:
            List[str]: Processed list of target parameters
        """
        result = target_parameters or self.cfg.lora_target_parameters or []
        if isinstance(result, str):
            return [result]
        elif isinstance(result, list):
            return result
        else:
            return []

    def build_common_config_kwargs(self) -> Dict[str, Any]:
        """
        Build common configuration kwargs shared across adapter types.
        Returns:
            Dict[str, Any]: Common configuration parameters
        """
        config_kwargs = {}

        # LoftQ configuration
        loftq_bits = (
            self.cfg.peft
            and self.cfg.peft.loftq_config
            and self.cfg.peft.loftq_config.loftq_bits
        )
        if loftq_bits:
            from peft import LoftQConfig

            config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
            config_kwargs["init_lora_weights"] = "loftq"

        # LoRA weight initialization
        if self.cfg.peft_init_lora_weights:
            config_kwargs["init_lora_weights"] = self.cfg.peft_init_lora_weights

        # DoRA configuration
        if self.cfg.peft_use_dora:
            config_kwargs["use_dora"] = self.cfg.peft_use_dora
            LOG.info("Initializing LoRA weights using DoRA. This might take longer.")

        # RSLoRA configuration
        if self.cfg.peft_use_rslora:
            config_kwargs["use_rslora"] = self.cfg.peft_use_rslora

        # Layer replication
        if self.cfg.peft_layer_replication:
            config_kwargs["layer_replication"] = self.cfg.peft_layer_replication

        return config_kwargs

    def setup_quantization_for_training(self, model: Union[PreTrainedModel, PeftModel]):
        """
        Setup quantization metadata for training.
        Args:
            model: The model to setup quantization for
        """
        from axolotl.loaders.adapter import setup_quantized_meta_for_peft

        if (
            self.cfg.fsdp_config
            and self.cfg.adapter
            and self.cfg.fsdp_config.cpu_ram_efficient_loading
            and self.rank != 0
        ):
            setup_quantized_meta_for_peft(model)

    def setup_quantization_for_training_post_build(
        self, model: Union[PreTrainedModel, PeftModel]
    ):
        """
        Setup quantization metadata after model building for training.
        Args:
            model: The model to setup quantization for
        """
        from axolotl.loaders.adapter import setup_quantized_peft_meta_for_training

        if (
            self.cfg.fsdp_config
            and self.cfg.adapter
            and self.cfg.fsdp_config.cpu_ram_efficient_loading
            and self.rank != 0
        ):
            setup_quantized_peft_meta_for_training(model)

    def load_pretrained_adapter(
        self, model: PreTrainedModel, inference: bool = False
    ) -> Union[PreTrainedModel, PeftModel]:
        """
        Load a pretrained adapter from a directory.
        Args:
            model: Base model to load adapter onto
            inference: Whether this is for inference mode
        Returns:
            PeftModel: Model with loaded adapter
        """

        if not self.cfg.lora_model_dir:
            return model

        LOG.debug(f"Loading pretrained PEFT - {self.__class__.__name__}")
        model_kwargs: Dict[str, Any] = {}

        if self.cfg.lora_on_cpu:
            model_kwargs["max_memory"] = {"cpu": "256GiB"}
            model_kwargs["device_map"] = {"": "cpu"}

        return PeftModel.from_pretrained(
            model,
            self.cfg.lora_model_dir,
            is_trainable=(not inference),
            **model_kwargs,
        )

    def create_peft_model(
        self, model: PreTrainedModel, config: PeftConfig
    ) -> PeftModel:
        """
        Create a PEFT model from base model and config.
        Args:
            model: Base model
            config: PEFT configuration
        Returns:
            PeftModel: Created PEFT model
        """
        return get_peft_model(model, config)

    def print_trainable_parameters(self, model: Union[PreTrainedModel, PeftModel]):
        """
        Print the number of trainable parameters in the model.
        Args:
            model: The model to analyze
        """
        if self.rank == 0:
            try:
                model.print_trainable_parameters()
            except AttributeError as exc:
                LOG.warning(
                    "Exception caught during model.print_trainable_parameters(): %s",
                    exc,
                )
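
As a sketch of how the base class is meant to be extended (a hypothetical subclass for illustration only; LoraAdapterBuilder in this PR is the real implementation, and lora_dropout is an assumed config field):

class ExampleAdapterBuilder(BaseAdapterBuilder):
    """Illustrative subclass that reuses the shared logic in the base methods."""

    def build_config(self, model: PreTrainedModel, **kwargs) -> PeftConfig:
        # Extend the common LoRA kwargs with adapter-specific settings.
        return super().build_config(model, lora_dropout=self.cfg.lora_dropout, **kwargs)

    def build_model(self, model: PreTrainedModel, config: PeftConfig) -> PeftModel:
        return super().build_model(model, config)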