Skip to content

Commit 77137ae

Browse files
committed
Fix docstrings
1 parent 32e5e5f commit 77137ae

File tree

5 files changed

+53
-12
lines changed

5 files changed

+53
-12
lines changed

autointent/context/data_handler/_data_handler.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,63 +77,98 @@ def multilabel(self) -> bool:
7777

7878
def train_utterances(self, idx: int | None = None) -> list[str]:
7979
"""
80-
Get the training utterances.
80+
Retrieve training utterances from the dataset.
8181
82+
If a specific training split index is provided, retrieves utterances
83+
from the indexed training split. Otherwise, retrieves utterances from
84+
the primary training split.
85+
86+
:param idx: Optional index for a specific training split.
8287
:return: List of training utterances.
8388
"""
8489
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
8590
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
8691

8792
def train_labels(self, idx: int | None = None) -> list[LabelType]:
8893
"""
89-
Get the training labels.
94+
Retrieve training labels from the dataset.
95+
96+
If a specific training split index is provided, retrieves labels
97+
from the indexed training split. Otherwise, retrieves labels from
98+
the primary training split.
9099
100+
:param idx: Optional index for a specific training split.
91101
:return: List of training labels.
92102
"""
93103
split = f"{Split.TRAIN}_{idx}" if idx is not None else Split.TRAIN
94104
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
95105

96106
def validation_utterances(self, idx: int | None = None) -> list[str]:
97107
"""
98-
Get the validation utterances.
108+
Retrieve validation utterances from the dataset.
99109
110+
If a specific validation split index is provided, retrieves utterances
111+
from the indexed validation split. Otherwise, retrieves utterances from
112+
the primary validation split.
113+
114+
:param idx: Optional index for a specific validation split.
100115
:return: List of validation utterances.
101116
"""
102117
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
103118
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
104119

105120
def validation_labels(self, idx: int | None = None) -> list[LabelType]:
106121
"""
107-
Get the validatio labels.
122+
Retrieve validation labels from the dataset.
123+
124+
If a specific validation split index is provided, retrieves labels
125+
from the indexed validation split. Otherwise, retrieves labels from
126+
the primary validation split.
108127
109-
:return: List of validatio labels.
128+
:param idx: Optional index for a specific validation split.
129+
:return: List of validation labels.
110130
"""
111131
split = f"{Split.VALIDATION}_{idx}" if idx is not None else Split.VALIDATION
112132
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
113133

114134
def test_utterances(self, idx: int | None = None) -> list[str]:
115135
"""
116-
Get the test utterances.
136+
Retrieve test utterances from the dataset.
117137
138+
If a specific test split index is provided, retrieves utterances
139+
from the indexed test split. Otherwise, retrieves utterances from
140+
the primary test split.
141+
142+
:param idx: Optional index for a specific test split.
118143
:return: List of test utterances.
119144
"""
120145
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
121146
return cast(list[str], self.dataset[split][self.dataset.utterance_feature])
122147

123148
def test_labels(self, idx: int | None = None) -> list[LabelType]:
124149
"""
125-
Get the test labels.
150+
Retrieve test labels from the dataset.
151+
152+
If a specific test split index is provided, retrieves labels
153+
from the indexed test split. Otherwise, retrieves labels from
154+
the primary test split.
126155
156+
:param idx: Optional index for a specific test split.
127157
:return: List of test labels.
128158
"""
129159
split = f"{Split.TEST}_{idx}" if idx is not None else Split.TEST
130160
return cast(list[LabelType], self.dataset[split][self.dataset.label_feature])
131161

132162
def oos_utterances(self, idx: int | None = None) -> list[str]:
133163
"""
134-
Get the out-of-scope utterances.
164+
Retrieve out-of-scope (OOS) utterances from the dataset.
165+
166+
If the dataset contains out-of-scope samples, retrieves the utterances
167+
from the specified OOS split index (if provided) or the primary OOS split.
168+
Returns an empty list if no OOS samples are available in the dataset.
135169
136-
:return: List of out-of-scope utterances if available, otherwise an empty list.
170+
:param idx: Optional index for a specific OOS split.
171+
:return: List of out-of-scope utterances, or an empty list if unavailable.
137172
"""
138173
if self.has_oos_samples():
139174
split = f"{Split.OOS}_{idx}" if idx is not None else Split.OOS

autointent/context/optimization_info/_optimization_info.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ def get_best_oos_scores(self, split: Literal["train", "validation", "test"]) ->
178178
"""
179179
Retrieve the out-of-scope scores from the best scorer node.
180180
181-
:return: Out-of-scope scores as a numpy array.
181+
:param split: The data split for which to retrieve the OOS scores.
182+
Must be one of "train", "validation", or "test".
183+
:return: A numpy array containing OOS scores for the specified split,
184+
or `None` if no OOS scores are available.
182185
"""
183186
best_scorer_artifact: ScorerArtifact = self._get_best_artifact(node_type=NodeType.scoring) # type: ignore[assignment]
184187
if best_scorer_artifact.oos_scores is not None:
@@ -219,7 +222,7 @@ def get_inference_nodes_config(self) -> list[InferenceNodeConfig]:
219222
)
220223
return res
221224

222-
def get_best_module(self, node_type: str) -> "Module | None":
225+
def _get_best_module(self, node_type: str) -> "Module | None":
223226
"""
224227
Retrieve the best module for a specific node type.
225228
@@ -237,5 +240,5 @@ def get_best_modules(self) -> dict[NodeTypeType, "Module"]:
237240
238241
:return: Dictionary of the best modules for each node type.
239242
"""
240-
res = {nt: self.get_best_module(nt) for nt in NODE_TYPES}
243+
res = {nt: self._get_best_module(nt) for nt in NODE_TYPES}
241244
return {nt: m for nt, m in res.items() if m is not None} # type: ignore[misc]

autointent/modules/prediction/_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def score(
5050
Calculate metric on test set and return metric value.
5151
5252
:param context: Context to score
53+
:param split: Target split
5354
:param metric_fn: Metric function
5455
:return: Score
5556
"""

autointent/modules/retrieval/_vectordb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def score(
121121
Evaluate the retrieval model using a specified metric function.
122122
123123
:param context: The context containing test data and labels.
124+
:param split: Target split
124125
:param metric_fn: Function to compute the retrieval metric.
125126
:return: Computed metric score.
126127
"""

autointent/modules/scoring/_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def score(
3131
Evaluate the scorer on a test set and compute the specified metric.
3232
3333
:param context: Context containing test set and other data.
34+
:param split: Target split
3435
:param metric_fn: Function to compute the scoring metric.
3536
:return: Computed metric value for the test set.
3637
"""

0 commit comments

Comments
 (0)