Skip to content

Commit a79a714

Browse files
committed
fix codestyle
1 parent 3bdd40e commit a79a714

File tree

26 files changed

+103
-259
lines changed

26 files changed

+103
-259
lines changed

autointent/context/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def set_dataset(self, dataset: Dataset, force_multilabel: bool = False) -> None:
9090
:param force_multilabel: Whether to force multilabel classification.
9191
"""
9292
self.data_handler = DataHandler(
93-
dataset=dataset, force_multilabel=force_multilabel, random_seed=self.seed,
93+
dataset=dataset,
94+
force_multilabel=force_multilabel,
95+
random_seed=self.seed,
9496
)
9597

9698
def get_best_index(self) -> VectorIndex:

autointent/context/data_handler/dataset.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Dataset(dict[str, HFDataset]):
2929
label_feature = "label"
3030
utterance_feature = "utterance"
3131

32-
def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None:
32+
def __init__(self, *args: Any, intents: list[Intent], **kwargs: Any) -> None: # noqa: ANN401
3333
super().__init__(*args, **kwargs)
3434

3535
self.intents = intents
@@ -51,11 +51,13 @@ def n_classes(self) -> int:
5151
@classmethod
5252
def from_json(cls, filepath: str | Path) -> "Dataset":
5353
from .reader import JsonReader
54+
5455
return JsonReader().read(filepath)
5556

5657
@classmethod
5758
def from_dict(cls, mapping: dict[str, Any]) -> "Dataset":
5859
from .reader import DictReader
60+
5961
return DictReader().read(mapping)
6062

6163
@classmethod
@@ -82,7 +84,7 @@ def to_multilabel(self) -> Self:
8284
self[split_name] = split.map(self._to_multilabel)
8385
return self
8486

85-
def push_to_hub(self, repo_id: str)-> None:
87+
def push_to_hub(self, repo_id: str) -> None:
8688
for split_name, split in self.items():
8789
split.push_to_hub(repo_id, split=split_name)
8890

@@ -95,10 +97,7 @@ def get_tags(self) -> list[Tag]:
9597
for intent in self.intents:
9698
for tag in intent.tags:
9799
tag_mapping[tag].append(intent.id)
98-
return [
99-
Tag(name=tag, intent_ids=intent_ids)
100-
for tag, intent_ids in tag_mapping.items()
101-
]
100+
return [Tag(name=tag, intent_ids=intent_ids) for tag, intent_ids in tag_mapping.items()]
102101

103102
def get_n_classes(self, split: str) -> int:
104103
classes = set()

autointent/context/data_handler/multilabel_generation.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

autointent/context/data_handler/reader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class BaseReader(ABC):
13-
def read(self, *args: Any, **kwargs: Any) -> Dataset:
13+
def read(self, *args: Any, **kwargs: Any) -> Dataset: # noqa: ANN401
1414
dataset_reader = DatasetValidator.validate(self._read(*args, **kwargs))
1515
splits = dataset_reader.model_dump(exclude={"intents"}, exclude_defaults=True)
1616
return Dataset(
@@ -19,8 +19,7 @@ def read(self, *args: Any, **kwargs: Any) -> Dataset:
1919
)
2020

2121
@abstractmethod
22-
def _read(self, *args: Any, **kwargs: Any) -> DatasetReader:
23-
...
22+
def _read(self, *args: Any, **kwargs: Any) -> DatasetReader: ... # noqa: ANN401
2423

2524

2625
class DictReader(BaseReader):

autointent/context/data_handler/sampling.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

autointent/context/data_handler/schemas.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def _validate_label(self) -> Self:
3737
label = [self.label] if isinstance(self.label, int) else self.label
3838
if not label:
3939
message = (
40-
"The `label` field cannot be empty for a multilabel sample. "
41-
"Please provide at least one valid label."
40+
"The `label` field cannot be empty for a multilabel sample. " "Please provide at least one valid label."
4241
)
4342
raise ValueError(message)
4443
if any(label_ < 0 for label_ in label):

autointent/context/data_handler/stratification.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ class StratifiedSplitter:
2525
"""
2626

2727
def __init__(
28-
self, test_size: float, label_feature: str, random_seed: int, shuffle: bool = True,
28+
self,
29+
test_size: float,
30+
label_feature: str,
31+
random_seed: int,
32+
shuffle: bool = True,
2933
) -> None:
3034
"""
3135
Initialize the StratifiedSplitter.
@@ -52,7 +56,7 @@ def __call__(self, dataset: HFDataset, multilabel: bool) -> tuple[Dataset, Datas
5256
return dataset.select(splits[0]), dataset.select(splits[1])
5357

5458
def _split(self, dataset: HFDataset) -> Sequence[npt.NDArray[np.int_]]:
55-
return train_test_split( # type: ignore[no-any-return]
59+
return train_test_split( # type: ignore[no-any-return]
5660
np.arange(len(dataset)),
5761
test_size=self.test_size,
5862
random_state=self.random_seed,
@@ -81,7 +85,9 @@ def split_dataset(dataset: Dataset, random_seed: int) -> Dataset:
8185
:return: The input dataset with training and testing splits.
8286
"""
8387
splitter = StratifiedSplitter(
84-
test_size=0.25, label_feature=dataset.label_feature, random_seed=random_seed,
88+
test_size=0.25,
89+
label_feature=dataset.label_feature,
90+
random_seed=random_seed,
8591
)
8692
dataset[Split.TRAIN], dataset[Split.TEST] = splitter(dataset[Split.TRAIN], dataset.multilabel)
8793
return dataset

autointent/context/data_handler/tags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections import defaultdict
88
from dataclasses import dataclass, field
99

10-
from .schemas import Dataset
10+
from .dataset import Dataset
1111

1212

1313
@dataclass

autointent/context/data_handler/validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ def _validate_split(self, split: list[Sample]) -> Self:
3636
intent_ids = {intent.id for intent in self.intents}
3737
for sample in split:
3838
message = (
39-
f"Sample with label {sample.label} references a non-existent intent ID. "
40-
f"Valid IDs are {intent_ids}."
39+
f"Sample with label {sample.label} references a non-existent intent ID. " f"Valid IDs are {intent_ids}."
4140
)
4241
if isinstance(sample.label, int) and sample.label not in intent_ids:
4342
raise ValueError(message)

autointent/context/embedder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def delete(self) -> None:
7474
"""Delete the embedding model and its associated directory."""
7575
self.clear_ram()
7676
shutil.rmtree(
77-
self.dump_dir, ignore_errors=True,
77+
self.dump_dir,
78+
ignore_errors=True,
7879
) # TODO: `ignore_errors=True` is workaround for PermissionError: [WinError 5] Access is denied
7980

8081
def dump(self, path: Path) -> None:

0 commit comments

Comments
 (0)