
Commit c6cab95: fix dump tools
1 parent b8fb273

7 files changed: +49 additions, -63 deletions

autointent/_dump_tools.py

Lines changed: 37 additions & 22 deletions
@@ -7,8 +7,16 @@
 import joblib
 import numpy as np
 import numpy.typing as npt
+from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
 
 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -34,6 +42,7 @@ class Dumper:
     pydantic_models: str = "pydantic"
     hf_models = "hf_models"
     hf_tokenizers = "hf_tokenizers"
+    peft_models = "peft_models"
 
     @staticmethod
     def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
@@ -52,6 +61,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
             path / Dumper.pydantic_models,
             path / Dumper.hf_models,
             path / Dumper.hf_tokenizers,
+            path / Dumper.peft_models,
         ]
         for subdir in subdirectories:
             subdir.mkdir(parents=True, exist_ok=exists_ok)
@@ -101,25 +111,34 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
             except Exception as e:
                 msg = f"Error dumping pydantic model {key}: {e}"
                 logging.exception(msg)
-            elif (key == "_model" or "model" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PeftModel):
+                try:
+                    if val._is_prompt_learning:  # noqa: SLF001
+                        model_path = path / Dumper.peft_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        val.save_pretrained(model_path / "peft")  # save peft config and prompt encoder
+                        val.base_model.save_pretrained(model_path / "base_model")  # save bert classifier
+                    else:
+                        model_path = path / Dumper.hf_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        merged_model: PreTrainedModel = val.merge_and_unload()
+                        merged_model.save_pretrained(model_path)
+                except Exception as e:
+                    msg = f"Error dumping PeftModel {key}: {e}"
+                    logger.exception(msg)
+            elif isinstance(val, PreTrainedModel):
                 model_path = path / Dumper.hf_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
                 try:
                     val.save_pretrained(model_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (model_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
                 except Exception as e:
                     msg = f"Error dumping HF model {key}: {e}"
                     logger.exception(msg)
-            elif (key == "_tokenizer" or "tokenizer" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                 tokenizer_path = path / Dumper.hf_tokenizers / key
                 tokenizer_path.mkdir(parents=True, exist_ok=True)
                 try:
                     val.save_pretrained(tokenizer_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (tokenizer_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
                 except Exception as e:
                     msg = f"Error dumping HF tokenizer {key}: {e}"
                     logger.exception(msg)
@@ -202,29 +221,25 @@ def load(  # noqa: C901, PLR0912, PLR0915
                     msg = f"Error loading Pydantic model from {model_dir}: {e}"
                     logger.exception(msg)
                     continue
+            elif child.name == Dumper.peft_models:
+                for model_dir in child.iterdir():
+                    try:
+                        model = AutoModelForSequenceClassification.from_pretrained(model_dir / "base_model")
+                        hf_models[model_dir.name] = PeftModel.from_pretrained(model, model_dir / "peft")
+                    except Exception as e:  # noqa: PERF203
+                        msg = f"Error loading PeftModel {model_dir.name}: {e}"
+                        logger.exception(msg)
             elif child.name == Dumper.hf_models:
                 for model_dir in child.iterdir():
                     try:
-                        with (model_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        model_class = getattr(module, class_info["name"])
-
-                        hf_models[model_dir.name] = model_class.from_pretrained(model_dir)
+                        hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF model {model_dir.name}: {e}"
                         logger.exception(msg)
             elif child.name == Dumper.hf_tokenizers:
                 for tokenizer_dir in child.iterdir():
                     try:
-                        with (tokenizer_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        tokenizer_class = getattr(module, class_info["name"])
-
-                        hf_tokenizers[tokenizer_dir.name] = tokenizer_class.from_pretrained(tokenizer_dir)
+                        hf_tokenizers[tokenizer_dir.name] = AutoTokenizer.from_pretrained(tokenizer_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                         logger.exception(msg)
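Taken together, the new branches form a simple round trip. The sketch below replays each path outside Dumper, using prajjwal1/bert-tiny from the test configs; the label count, dump paths, and the attribute name _model are illustrative assumptions, not repo code:

from peft import LoraConfig, PeftModel, PromptEncoderConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=3)

# LoRA is not prompt learning: Dumper merges the adapter into the base
# weights and writes an ordinary HF checkpoint under hf_models/.
lora = get_peft_model(base, LoraConfig(task_type="SEQ_CLS"))
lora.merge_and_unload().save_pretrained("dump/hf_models/_model")
reloaded = AutoModelForSequenceClassification.from_pretrained("dump/hf_models/_model")

# A prompt-learning adapter (p-tuning) cannot be merged, so Dumper writes
# two artifacts under peft_models/ and load() reassembles them.
pt_base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=3)
ptuned = get_peft_model(pt_base, PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10))
ptuned.save_pretrained("dump/peft_models/_model/peft")  # prompt encoder + peft config
ptuned.base_model.save_pretrained("dump/peft_models/_model/base_model")  # plain classifier
restored = PeftModel.from_pretrained(
    AutoModelForSequenceClassification.from_pretrained("dump/peft_models/_model/base_model"),
    "dump/peft_models/_model/peft",
)

Loading now goes through the Auto classes, which is what made the class_info.json bookkeeping deletable: AutoModelForSequenceClassification and AutoTokenizer recover the concrete class from the checkpoint's own config and tokenizer files.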

autointent/modules/scoring/_bert.py

Lines changed: 5 additions & 2 deletions
@@ -11,6 +11,9 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
     Trainer,
     TrainingArguments,
 )
@@ -26,8 +29,8 @@ class BertScorer(BaseScorer):
     name = "bert"
     supports_multiclass = True
     supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
+    _model: PreTrainedModel
+    _tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast
 
     def __init__(
         self,
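Typing _model and _tokenizer concretely is what the new isinstance dispatch in Dumper.dump depends on; the old code guessed from attribute names ("model" in key.lower()) plus a hasattr check. A standalone sketch of the branch order (dump_category is a hypothetical helper, not repo code):

from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast


def dump_category(val: object) -> str:
    # Same order as Dumper.dump: PEFT adapters get special handling,
    # then plain HF models, then tokenizers (slow or fast).
    if isinstance(val, PeftModel):
        return "peft_models"
    if isinstance(val, PreTrainedModel):
        return "hf_models"
    if isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
        return "hf_tokenizers"
    return "other"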

autointent/modules/scoring/_lora/lora.py

Lines changed: 3 additions & 13 deletions
@@ -3,7 +3,6 @@
 from typing import Any
 
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForSequenceClassification
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
@@ -59,10 +58,6 @@ class BERTLoRAScorer(BertScorer):
     """
 
     name = "lora"
-    supports_multiclass = True
-    supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
 
     def __init__(
         self,
@@ -72,7 +67,7 @@ def __init__(
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
-        **lora_kwargs: dict[str, Any],
+        **lora_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
             classification_model_config=classification_model_config,
@@ -93,7 +88,7 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **lora_kwargs: dict[str, Any],
+        **lora_kwargs: Any,  # noqa: ANN401
     ) -> "BERTLoRAScorer":
         if classification_model_config is None:
             classification_model_config = context.resolve_transformer()
@@ -108,10 +103,5 @@ def from_context(
         )
 
     def _initialize_model(self) -> None:
-        self._model = AutoModelForSequenceClassification.from_pretrained(
-            self.classification_model_config.model_name,
-            num_labels=self._n_classes,
-            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-            trust_remote_code=self.classification_model_config.trust_remote_code,
-        )
+        super()._initialize_model()
         self._model = get_peft_model(self._model, self._lora_config)
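Two small fixes here deserve a note. The annotation **lora_kwargs: dict[str, Any] claimed each keyword value is a dict, while the values forwarded into LoraConfig are scalars, so Any (with the ANN401 suppression) is the accurate type. And _initialize_model now delegates the from_pretrained setup to BertScorer, only layering the adapter on top. A hedged usage sketch; the import path and the forwarding of lora_kwargs into LoraConfig are assumptions, and r / lora_alpha are standard LoraConfig fields:

from autointent.modules.scoring import BERTLoRAScorer  # import path assumed

scorer = BERTLoRAScorer(
    classification_model_config="prajjwal1/bert-tiny",
    num_train_epochs=1,
    batch_size=8,
    r=8,            # scalar values, hence **lora_kwargs: Any
    lora_alpha=16,
)
# _initialize_model(): super() builds the plain classifier, then
# self._model = get_peft_model(self._model, self._lora_config)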

autointent/modules/scoring/_ptuning/ptuning.py

Lines changed: 4 additions & 20 deletions
@@ -2,11 +2,7 @@
 
 from typing import Any
 
-import torch
 from peft import PromptEncoderConfig, get_peft_model
-from transformers import (
-    AutoModelForSequenceClassification,
-)
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
@@ -54,10 +50,6 @@ class PTuningScorer(BertScorer):
     """
 
     name = "ptuning"
-    supports_multiclass = True
-    supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
 
     def __init__(
         self,
@@ -67,7 +59,7 @@ def __init__(
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
-        **ptuning_kwargs: dict[str, Any],
+        **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
             classification_model_config=classification_model_config,
@@ -77,8 +69,7 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._ptuning_config = PromptEncoderConfig(**ptuning_kwargs)  # type: ignore[arg-type]
-        torch.manual_seed(seed)
+        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)  # type: ignore[arg-type]
 
     @classmethod
     def from_context(
@@ -89,7 +80,7 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **ptuning_kwargs: dict[str, Any],
+        **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> "PTuningScorer":
         """Create a PTuningScorer instance using a Context object.
 
@@ -119,12 +110,5 @@ def from_context(
 
     def _initialize_model(self) -> None:
         """Initialize the model with P-tuning configuration."""
-        model_name = self.classification_model_config.model_name
-        self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            num_labels=self._n_classes,
-            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-            trust_remote_code=self.classification_model_config.trust_remote_code,
-            return_dict=True,
-        )
+        super()._initialize_model()
         self._model = get_peft_model(self._model, self._ptuning_config)
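Since the scorer now pins task_type="SEQ_CLS" itself (the only sensible task type for a sequence classifier), callers must stop passing it: a caller-supplied task_type would now collide with the pinned keyword inside PromptEncoderConfig. That is the entire content of the YAML and test changes below. The torch.manual_seed(seed) call also leaves __init__, so seeding is no longer a side effect of constructing the scorer. A construction sketch mirroring the updated tests; the import path is assumed:

from autointent.modules.scoring import PTuningScorer  # import path assumed

scorer = PTuningScorer(
    classification_model_config="prajjwal1/bert-tiny",
    num_train_epochs=1,
    batch_size=8,
    num_virtual_tokens=10,  # forwarded via **ptuning_kwargs into PromptEncoderConfig
    seed=42,                # passing task_type here would now raise a TypeError
)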

tests/assets/configs/multiclass.yaml

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@
         classification_model_config: ["prajjwal1/bert-tiny"]
         num_train_epochs: [1]
         batch_size: [8, 16]
-        task_type: ["SEQ_CLS"]
         num_virtual_tokens: [10, 20]
   - node_type: decision
     target_metric: decision_accuracy

tests/assets/configs/multilabel.yaml

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@
         classification_model_config: ["prajjwal1/bert-tiny"]
         num_train_epochs: [1]
         batch_size: [8]
-        task_type: ["SEQ_CLS"]
         num_virtual_tokens: [10, 20]
       - module_name: lora
         classification_model_config:

tests/modules/scoring/test_ptuning.py

Lines changed: 0 additions & 4 deletions
@@ -17,7 +17,6 @@ def test_ptuning_scorer_dump_load(dataset):
         classification_model_config="prajjwal1/bert-tiny",
         num_train_epochs=1,
         batch_size=8,
-        task_type="SEQ_CLS",
         num_virtual_tokens=10,
         seed=42,
     )
@@ -38,7 +37,6 @@ def test_ptuning_scorer_dump_load(dataset):
         classification_model_config="prajjwal1/bert-tiny",
         num_train_epochs=1,
         batch_size=8,
-        task_type="SEQ_CLS",
         num_virtual_tokens=10,
         seed=42,
     )
@@ -66,7 +64,6 @@ def test_ptuning_prediction(dataset):
         classification_model_config="prajjwal1/bert-tiny",
         num_train_epochs=1,
         batch_size=8,
-        task_type="SEQ_CLS",
         num_virtual_tokens=10,
         seed=42,
     )
@@ -106,7 +103,6 @@ def test_ptuning_cache_clearing(dataset):
         classification_model_config="prajjwal1/bert-tiny",
         num_train_epochs=1,
         batch_size=8,
-        task_type="SEQ_CLS",
         num_virtual_tokens=20,
         seed=42,
     )
