Commit dbc1724

implement peft lazy importing
1 parent 92f1cf3 commit dbc1724

6 files changed: +50 -20 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -44,14 +44,14 @@ dependencies = [
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
     "transformers[torch] (>=4.49.0,<5.0.0)",
-    "peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)",
     "aiometer (>=1.0.0,<2.0.0)",
     "aiofiles (>=24.1.0,<25.0.0)",
     "threadpoolctl (>=3.0.0,<4.0.0)",
 ]
 
 [project.optional-dependencies]
 catboost = ["catboost (>=1.2.8,<2.0.0)"]
+peft = ["peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)"]
 dspy = [
     "dspy (>=2.6.5,<3.0.0)",
 ]

src/autointent/_dump_tools/unit_dumpers.py

Lines changed: 11 additions & 6 deletions
@@ -8,7 +8,6 @@
 import joblib
 import numpy as np
 import numpy.typing as npt
-from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
 from transformers import (
@@ -28,6 +27,7 @@
 
 if TYPE_CHECKING:
     from catboost import CatBoostClassifier
+    from peft import PeftModel
 
 T = TypeVar("T")
 logger = logging.getLogger(__name__)
@@ -207,11 +207,11 @@ def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
         return isinstance(obj, BaseModel)
 
 
-class PeftModelDumper(BaseObjectDumper[PeftModel]):
+class PeftModelDumper(BaseObjectDumper["PeftModel"]):
     dir_or_file_name = "peft_models"
 
     @staticmethod
-    def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None:
+    def dump(obj: "PeftModel", path: Path, exists_ok: bool) -> None:
         path.mkdir(parents=True, exist_ok=exists_ok)
         if obj._is_prompt_learning:  # noqa: SLF001
             # strategy to save prompt learning models: save prompt encoder and bert classifier separately
@@ -227,12 +227,13 @@ def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None:
             merged_model.save_pretrained(lora_path)
 
     @staticmethod
-    def load(path: Path, **kwargs: Any) -> PeftModel:  # noqa: ANN401, ARG004
+    def load(path: Path, **kwargs: Any) -> "PeftModel":  # noqa: ANN401, ARG004
+        peft = require("peft", extra="peft")
         if (path / "ptuning").exists():
             # prompt learning model
             ptuning_path = path / "ptuning"
             model = AutoModelForSequenceClassification.from_pretrained(ptuning_path / "base_model")
-            return PeftModel.from_pretrained(model, ptuning_path / "peft")
+            return peft.PeftModel.from_pretrained(model, ptuning_path / "peft")
         if (path / "lora").exists():
             # merged lora model
             lora_path = path / "lora"
@@ -242,7 +243,11 @@ def load(path: Path, **kwargs: Any) -> PeftModel:  # noqa: ANN401, ARG004
 
     @classmethod
     def check_isinstance(cls, obj: Any) -> bool:  # noqa: ANN401
-        return isinstance(obj, PeftModel)
+        try:
+            peft = require("peft", extra="peft")
+            return isinstance(obj, peft.PeftModel)
+        except ImportError:
+            return False
 
 
 class HFModelDumper(BaseObjectDumper[PreTrainedModel]):
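
The load and check_isinstance changes above rely on a require helper; its implementation is not part of this diff, so the sketch below is only an assumption about what such a helper typically does, inferred from the call sites require("peft", extra="peft"):

# Hypothetical sketch; the real require() in autointent._utils is not shown in
# this commit, so its exact behavior and error message here are assumptions.
import importlib
from types import ModuleType


def require(name: str, extra: str | None = None) -> ModuleType:
    """Import an optional dependency, pointing the user to the matching extra on failure."""
    try:
        return importlib.import_module(name)
    except ImportError as err:
        hint = f" Install it via the '{extra}' optional extra." if extra else ""
        raise ImportError(f"Optional dependency '{name}' is not installed.{hint}") from err

With a helper along these lines, peft is no longer needed at import time, and check_isinstance can return False when peft is absent instead of breaking the module import.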

src/autointent/modules/scoring/_lora/lora.py

Lines changed: 15 additions & 6 deletions
@@ -1,16 +1,18 @@
 """BertScorer class for transformer-based classification with LoRA."""
 
 from pathlib import Path
-from typing import Any, Literal
-
-from peft import LoraConfig, get_peft_model
+from typing import TYPE_CHECKING, Any, Literal
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
 from autointent._dump_tools import Dumper
+from autointent._utils import require
 from autointent.configs import EarlyStoppingConfig, HFModelConfig
 from autointent.modules.scoring._bert import BertScorer
 
+if TYPE_CHECKING:
+    from peft import LoraConfig
+
 
 class BERTLoRAScorer(BertScorer):
     """BERTLoRAScorer class for transformer-based classification with LoRA (Low-Rank Adaptation).
@@ -56,6 +58,8 @@ class BERTLoRAScorer(BertScorer):
 
     name = "lora"
 
+    _lora_config: "LoraConfig"
+
     def __init__(
         self,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
@@ -67,6 +71,11 @@ def __init__(
         print_progress: bool = False,
         **lora_kwargs: Any,  # noqa: ANN401
     ) -> None:
+        # Lazy import peft
+        peft = require("peft", extra="peft")
+        self._LoraConfig = peft.LoraConfig
+        self._get_peft_model = peft.get_peft_model
+
         # early stopping doesnt work with lora for now https://github.com/huggingface/transformers/issues/38130
         early_stopping_config = EarlyStoppingConfig(metric=None)  # disable early stopping
 
@@ -80,7 +89,7 @@ def __init__(
             early_stopping_config=early_stopping_config,
             print_progress=print_progress,
         )
-        self._lora_config = LoraConfig(**lora_kwargs)
+        self._lora_config = self._LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
@@ -107,7 +116,7 @@ def from_context(
 
     def _initialize_model(self) -> Any:  # noqa: ANN401
         model = super()._initialize_model()
-        return get_peft_model(model, self._lora_config)
+        return self._get_peft_model(model, self._lora_config)
 
     def dump(self, path: str) -> None:
-        Dumper.dump(self, Path(path), exclude=[LoraConfig])
+        Dumper.dump(self, Path(path), exclude=[self._LoraConfig])
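
Both scorer modules follow the same shape: type-only names stay under TYPE_CHECKING with string annotations, while the runtime symbols are resolved inside __init__ via require and cached on the instance, which is also why Dumper.dump now excludes self._LoraConfig rather than a module-level class. A stripped-down sketch of that shape, with an illustrative class name that is not part of the codebase:

# Illustrative sketch of the lazy-import pattern; the class and attribute names
# here are made up for demonstration, not taken from autointent.
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from peft import LoraConfig  # evaluated only by static type checkers


class LazyLoraExample:
    _lora_config: "LoraConfig"

    def __init__(self, **lora_kwargs: Any) -> None:
        import peft  # resolved only when an instance is actually constructed

        self._LoraConfig = peft.LoraConfig
        self._get_peft_model = peft.get_peft_model
        self._lora_config = self._LoraConfig(**lora_kwargs)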

src/autointent/modules/scoring/_ptuning/ptuning.py

Lines changed: 19 additions & 7 deletions
@@ -1,17 +1,20 @@
 """PTuningScorer class for ptuning-based classification."""
 
 from pathlib import Path
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
-from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType, get_peft_model
 from pydantic import PositiveInt
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
 from autointent._dump_tools import Dumper
+from autointent._utils import require
 from autointent.configs import EarlyStoppingConfig, HFModelConfig
 from autointent.modules.scoring._bert import BertScorer
 
+if TYPE_CHECKING:
+    from peft import PromptEncoderConfig
+
 
 class PTuningScorer(BertScorer):
     """PEFT P-tuning scorer.
@@ -47,6 +50,8 @@ class PTuningScorer(BertScorer):
 
     name = "ptuning"
 
+    _ptuning_config: "PromptEncoderConfig"
+
     def __init__(  # noqa: PLR0913
         self,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
@@ -64,6 +69,13 @@ def __init__(  # noqa: PLR0913
         print_progress: bool = False,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> None:
+        # Lazy import peft
+        peft = require("peft", extra="peft")
+        self._PromptEncoderConfig = peft.PromptEncoderConfig
+        self._PromptEncoderReparameterizationType = peft.PromptEncoderReparameterizationType
+        self._TaskType = peft.TaskType
+        self._get_peft_model = peft.get_peft_model
+
         super().__init__(
             classification_model_config=classification_model_config,
             num_train_epochs=num_train_epochs,
@@ -74,9 +86,9 @@ def __init__(  # noqa: PLR0913
             early_stopping_config=early_stopping_config,
             print_progress=print_progress,
         )
-        self._ptuning_config = PromptEncoderConfig(
-            task_type=TaskType.SEQ_CLS,
-            encoder_reparameterization_type=PromptEncoderReparameterizationType(encoder_reparameterization_type),
+        self._ptuning_config = self._PromptEncoderConfig(
+            task_type=self._TaskType.SEQ_CLS,
+            encoder_reparameterization_type=self._PromptEncoderReparameterizationType(encoder_reparameterization_type),
             num_virtual_tokens=num_virtual_tokens,
             encoder_dropout=encoder_dropout,
             encoder_hidden_size=encoder_hidden_size,
@@ -139,7 +151,7 @@ def from_context(  # noqa: PLR0913
     def _initialize_model(self) -> Any:  # noqa: ANN401
        """Initialize the model with P-tuning configuration."""
         model = super()._initialize_model()
-        return get_peft_model(model, self._ptuning_config)
+        return self._get_peft_model(model, self._ptuning_config)
 
     def dump(self, path: str) -> None:
-        Dumper.dump(self, Path(path), exclude=[PromptEncoderConfig])
+        Dumper.dump(self, Path(path), exclude=[self._PromptEncoderConfig])

tests/modules/scoring/test_lora.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 from autointent.context.data_handler import DataHandler
 from autointent.modules import BERTLoRAScorer
 
+pytest.importorskip("peft")
+
 
 def test_lora_scorer_dump_load(dataset):
     """Test that BERTLoRAScorer can be saved and loaded while preserving predictions."""

tests/modules/scoring/test_ptuning.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 from autointent.context.data_handler import DataHandler
 from autointent.modules import PTuningScorer
 
+pytest.importorskip("peft")
+
 
 def test_ptuning_scorer_dump_load(dataset):
     """Test that PTuningScorer can be saved and loaded while preserving predictions."""
