Skip to content

Commit 8b94741

Browse files
committed
fix typing
1 parent a74e5dd commit 8b94741

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

autointent/metrics/retrieval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,14 @@ def ignore_oos(func: RetrievalMetricFn) -> RetrievalMetricFn:
114114
"""Ignore OOS in metrics calculation (decorator)."""
115115

116116
@wraps(func)
117-
def wrapper(query_labels: list[Any | None], candidates_labels: list[Any]) -> float:
117+
def wrapper(query_labels: LABELS_VALUE_TYPE, candidates_labels: CANDIDATE_TYPE) -> float:
118118
query_labels_filtered = [lab for lab in query_labels if lab is not None]
119119
candidates_labels_filtered = [
120120
cand for cand, lab in zip(candidates_labels, query_labels, strict=True) if lab is not None
121121
]
122-
return func(query_labels_filtered, candidates_labels_filtered)
122+
return func(query_labels_filtered, candidates_labels_filtered) # type: ignore[arg-type]
123123

124-
return wrapper
124+
return wrapper # type: ignore[return-value]
125125

126126

127127
@ignore_oos

autointent/metrics/scoring.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from functools import wraps
5-
from typing import Any, Protocol
5+
from typing import Protocol
66

77
import numpy as np
88
from sklearn.metrics import coverage_error, label_ranking_average_precision_score, label_ranking_loss, roc_auc_score
@@ -34,10 +34,10 @@ def ignore_oos(func: ScoringMetricFn) -> ScoringMetricFn:
3434
"""Ignore OOS in metrics calculation (decorator)."""
3535

3636
@wraps(func)
37-
def wrapper(labels: list[Any | None], scores: list[Any]) -> float:
37+
def wrapper(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
3838
labels_filtered = [lab for lab in labels if lab is not None]
3939
scores_filtered = [score for score, lab in zip(scores, labels, strict=True) if lab is not None]
40-
return func(labels_filtered, scores_filtered)
40+
return func(labels_filtered, scores_filtered) # type: ignore[arg-type]
4141

4242
return wrapper
4343

0 commit comments

Comments
 (0)