Skip to content

Commit bdca370

Browse files
authored
update numpy typing (#188)
1 parent 791669e commit bdca370

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

autointent/modules/decision/_adaptive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non
148148
res = (scores >= thresh[:, None]).astype(int)
149149
if tags:
150150
res = apply_tags(res, scores, tags)
151-
y_pred: list[MultiLabel] = res.tolist() # type: ignore[assignment]
151+
y_pred: list[MultiLabel] = res.tolist()
152152
return [lab if sum(lab) > 0 else None for lab in y_pred]
153153

154154

autointent/modules/decision/_jinoos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def predict(self, scores: npt.NDArray[Any]) -> list[int | None]:
113113
if scores.shape[1] != self._n_classes:
114114
raise MismatchNumClassesError
115115
pred_classes, best_scores = _predict(scores)
116-
y_pred: list[int] = _detect_oos(pred_classes, best_scores, self._thresh).tolist() # type: ignore[assignment]
116+
y_pred: list[int] = _detect_oos(pred_classes, best_scores, self._thresh).tolist()
117117
return [lab if lab != -1 else None for lab in y_pred]
118118

119119
@staticmethod

autointent/modules/decision/_threshold.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def multiclass_predict(scores: npt.NDArray[Any], thresh: float | npt.NDArray[Any
169169
thresh_selected = thresh[pred_classes]
170170
pred_classes[best_scores < thresh_selected] = -1 # out of scope
171171

172-
y_pred: list[int] = pred_classes.tolist() # type: ignore[assignment]
172+
y_pred: list[int] = pred_classes.tolist()
173173
return [lab if lab != -1 else None for lab in y_pred]
174174

175175

@@ -191,5 +191,5 @@ def multilabel_predict(
191191
res = (scores >= thresh).astype(int) if isinstance(thresh, float) else (scores >= thresh[None, :]).astype(int)
192192
if tags:
193193
res = apply_tags(res, scores, tags)
194-
y_pred: list[MultiLabel] = res.tolist() # type: ignore[assignment]
194+
y_pred: list[MultiLabel] = res.tolist()
195195
return [lab if sum(lab) > 0 else None for lab in y_pred]

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list
180180

181181
flattened_cross_encoder_scores: npt.NDArray[np.float64] = self._cross_encoder.predict(flattened_text_pairs)
182182
return [
183-
flattened_cross_encoder_scores[i : i + self.k].tolist() # type: ignore[misc]
183+
flattened_cross_encoder_scores[i : i + self.k].tolist()
184184
for i in range(0, len(flattened_cross_encoder_scores), self.k)
185185
]
186186

0 commit comments

Comments
 (0)