Commit e1cffb2

fix typing errors

1 parent c6cab95

File tree

4 files changed (+21, -22):

autointent/_dump_tools.py
autointent/modules/scoring/_bert.py
autointent/modules/scoring/_lora/lora.py
autointent/modules/scoring/_ptuning/ptuning.py

autointent/_dump_tools.py

Lines changed: 8 additions & 6 deletions
@@ -10,7 +10,7 @@
 from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
-from transformers import (
+from transformers import (  # type: ignore[attr-defined]
     AutoModelForSequenceClassification,
     AutoTokenizer,
     PreTrainedModel,
@@ -116,29 +116,31 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 if val._is_prompt_learning:  # noqa: SLF001
                     model_path = path / Dumper.peft_models / key
                     model_path.mkdir(parents=True, exist_ok=True)
-                    val.save_pretrained(model_path / "peft")  # save peft config and prompt encoder
-                    val.base_model.save_pretrained(model_path / "base_model")  # save bert classifier
+                    # save peft config and prompt encoder
+                    val.save_pretrained(str(model_path / "peft"))
+                    # save bert classifier
+                    val.base_model.save_pretrained(model_path / "base_model")  # type: ignore[attr-defined]
                 else:
                     model_path = path / Dumper.hf_models / key
                     model_path.mkdir(parents=True, exist_ok=True)
                     merged_model: PreTrainedModel = val.merge_and_unload()
-                    merged_model.save_pretrained(model_path)
+                    merged_model.save_pretrained(model_path)  # type: ignore[attr-defined]
             except Exception as e:
                 msg = f"Error dumping PeftModel {key}: {e}"
                 logger.exception(msg)
         elif isinstance(val, PreTrainedModel):
             model_path = path / Dumper.hf_models / key
             model_path.mkdir(parents=True, exist_ok=True)
             try:
-                val.save_pretrained(model_path)
+                val.save_pretrained(model_path)  # type: ignore[attr-defined]
             except Exception as e:
                 msg = f"Error dumping HF model {key}: {e}"
                 logger.exception(msg)
         elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
             tokenizer_path = path / Dumper.hf_tokenizers / key
             tokenizer_path.mkdir(parents=True, exist_ok=True)
             try:
-                val.save_pretrained(tokenizer_path)
+                val.save_pretrained(tokenizer_path)  # type: ignore[union-attr]
             except Exception as e:
                 msg = f"Error dumping HF tokenizer {key}: {e}"
                 logger.exception(msg)
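
The prompt-learning branch now saves in two parts: the peft config and prompt encoder under peft/, and the underlying classifier under base_model/. The added str() wrap matches peft's save_pretrained signature, which annotates save_directory as str. A minimal reload sketch, assuming that two-directory layout is the contract (the helper below is illustrative and not part of this commit):

from pathlib import Path

from peft import PeftModel
from transformers import AutoModelForSequenceClassification


def load_prompt_learning(model_path: Path) -> PeftModel:
    # rebuild the plain classifier saved under base_model/ ...
    base = AutoModelForSequenceClassification.from_pretrained(model_path / "base_model")
    # ... then re-attach the prompt encoder and peft config saved under peft/
    return PeftModel.from_pretrained(base, str(model_path / "peft"))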

autointent/modules/scoring/_bert.py

Lines changed: 5 additions & 8 deletions
@@ -11,9 +11,6 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,
-    PreTrainedModel,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
     Trainer,
     TrainingArguments,
 )
@@ -29,8 +26,8 @@ class BertScorer(BaseScorer):
     name = "bert"
     supports_multiclass = True
     supports_multilabel = True
-    _model: PreTrainedModel
-    _tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast
+    _model: Any  # transformers AutoModel factory returns Any
+    _tokenizer: Any  # transformers AutoTokenizer factory returns Any
 
     def __init__(
         self,
@@ -75,11 +72,11 @@ def from_context(
     def get_implicit_initialization_params(self) -> dict[str, Any]:
         return {"classification_model_config": self.classification_model_config.model_dump()}
 
-    def _initialize_model(self) -> None:
+    def _initialize_model(self) -> Any:  # noqa: ANN401
         label2id = {i: i for i in range(self._n_classes)}
         id2label = {i: i for i in range(self._n_classes)}
 
-        self._model = AutoModelForSequenceClassification.from_pretrained(
+        return AutoModelForSequenceClassification.from_pretrained(
             self.classification_model_config.model_name,
             trust_remote_code=self.classification_model_config.trust_remote_code,
             num_labels=self._n_classes,
@@ -99,7 +96,7 @@ def fit(
 
         self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
 
-        self._initialize_model()
+        self._model = self._initialize_model()
 
         use_cpu = self.classification_model_config.device == "cpu"
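
Net effect of these hunks: _initialize_model turns from a mutator into a factory. The base class builds and returns the model, fit() makes the single assignment to the Any-typed attribute, and the PEFT subclasses below wrap the return value instead of re-reading self._model. A toy sketch of the pattern (class names and stand-in objects here are illustrative, not AutoIntent's API):

from typing import Any


class BaseSketch:
    _model: Any  # mirrors BertScorer: from_pretrained factories are typed as Any

    def _initialize_model(self) -> Any:
        # stands in for AutoModelForSequenceClassification.from_pretrained(...)
        return {"weights": [0.0]}

    def fit(self) -> None:
        self._model = self._initialize_model()  # the single assignment point


class AdapterSketch(BaseSketch):
    def _initialize_model(self) -> Any:
        # wrap the returned model, as get_peft_model(model, config) does
        # in the lora and ptuning modules below
        return {"adapter": super()._initialize_model()}


scorer = AdapterSketch()
scorer.fit()
print(scorer._model)  # {'adapter': {'weights': [0.0]}}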

autointent/modules/scoring/_lora/lora.py

Lines changed: 4 additions & 4 deletions
@@ -77,7 +77,7 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._lora_config = LoraConfig(**lora_kwargs)  # type: ignore[arg-type]
+        self._lora_config = LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
@@ -102,6 +102,6 @@ def from_context(
             **lora_kwargs,
         )
 
-    def _initialize_model(self) -> None:
-        super()._initialize_model()
-        self._model = get_peft_model(self._model, self._lora_config)
+    def _initialize_model(self) -> Any:  # noqa: ANN401
+        model = super()._initialize_model()
+        return get_peft_model(model, self._lora_config)
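
What the new _initialize_model chain returns is an ordinary get_peft_model wrap. A standalone sketch, with an arbitrary checkpoint and LoRA hyperparameters rather than the module's defaults:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

# any sequence-classification checkpoint works; bert-base-uncased is only an example
base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
config = LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16, lora_dropout=0.1)

model = get_peft_model(base, config)  # a PeftModel wrapping the frozen base classifier
model.print_trainable_parameters()  # only LoRA matrices (plus, typically, the classifier head) train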

autointent/modules/scoring/_ptuning/ptuning.py

Lines changed: 4 additions & 4 deletions
@@ -69,7 +69,7 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)  # type: ignore[arg-type]
+        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)
 
     @classmethod
     def from_context(
@@ -108,7 +108,7 @@ def from_context(
             **ptuning_kwargs,
         )
 
-    def _initialize_model(self) -> None:
+    def _initialize_model(self) -> Any:  # noqa: ANN401
         """Initialize the model with P-tuning configuration."""
-        super()._initialize_model()
-        self._model = get_peft_model(self._model, self._ptuning_config)
+        model = super()._initialize_model()
+        return get_peft_model(model, self._ptuning_config)
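
Same factory pattern with a prompt encoder instead of LoRA. A standalone sketch (checkpoint and hyperparameters again arbitrary); models built this way report _is_prompt_learning as true, which is exactly the flag that routes them to the two-part save in _dump_tools.py above:

from peft import PromptEncoderConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=20, encoder_hidden_size=128)

model = get_peft_model(base, config)
print(model._is_prompt_learning)  # noqa: SLF001 -- True: the same flag the dumper checks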
