Skip to content

Commit a75ac12

Browse files
committed
fix typing errors
1 parent bcdf0f2 commit a75ac12

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

src/autointent/_dump_tools/unit_dumpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def load(path: Path, **kwargs: Any) -> "PeftModel": # noqa: ANN401, ARG004
228228
# prompt learning model
229229
ptuning_path = path / "ptuning"
230230
model = transformers.AutoModelForSequenceClassification.from_pretrained(ptuning_path / "base_model")
231-
return peft.PeftModel.from_pretrained(model, ptuning_path / "peft")
231+
return peft.PeftModel.from_pretrained(model, ptuning_path / "peft") # type: ignore[no-any-return]
232232
if (path / "lora").exists():
233233
# merged lora model
234234
lora_path = path / "lora"
@@ -278,7 +278,7 @@ def dump(obj: "PreTrainedTokenizer | PreTrainedTokenizerFast", path: Path, exist
278278
@staticmethod
279279
def load(path: Path, **kwargs: Any) -> "PreTrainedTokenizer | PreTrainedTokenizerFast": # noqa: ANN401, ARG004
280280
transformers = require("transformers", extra="transformers")
281-
return transformers.AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return,no-untyped-call]
281+
return transformers.AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return]
282282

283283
@classmethod
284284
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401

src/autointent/generation/utterances/_adversarial/human_utterance_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def augment(
114114
generated_split = HFDataset.from_list(new_samples)
115115
dataset[split_name] = concatenate_datasets([original_split, generated_split])
116116

117-
return [Sample(**sample) for sample in new_samples]
117+
return [Sample.model_validate(sample) for sample in new_samples]
118118

119119
async def augment_async(
120120
self, dataset: Dataset, split_name: str = Split.TRAIN, update_split: bool = True, n_final_per_class: int = 5

src/autointent/modules/scoring/_bert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def fit(
151151
) -> None:
152152
self._validate_task(labels)
153153

154-
self._tokenizer = self._AutoTokenizer.from_pretrained(self.classification_model_config.model_name) # type: ignore[no-untyped-call]
154+
self._tokenizer = self._AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
155155
self._model = self._initialize_model()
156156
tokenized_dataset = self._get_tokenized_dataset(utterances, labels)
157157
self._train(tokenized_dataset)

src/autointent/modules/scoring/_gcn/gcn_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def create_correlation_matrix(train_labels: torch.Tensor, num_classes: int, p: f
9090
reweighted_adj = adj_matrix_no_self_loop * weights_p.unsqueeze(1)
9191
reweighted_adj.fill_diagonal_(1 - p)
9292

93-
return cast(torch.Tensor, reweighted_adj)
93+
return reweighted_adj
9494

9595
def set_correlation_matrix(self, train_labels: torch.Tensor) -> None:
9696
corr_matrix = self.create_correlation_matrix(

This commit has 0 comments.

Comments (0)