Skip to content

Commit cb0ec0c

Browse files
committed
fix type hints
1 parent 630b054 commit cb0ec0c

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/jabs/classifier/classifier.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,13 @@ def sort_features_to_classify(self, features):
447447
features_sorted = features[classifier_columns]
448448
return features_sorted
449449

450-
def predict(self, features: dict, frame_indexes: np.ndarray | None = None) -> np.ndarray:
450+
def predict(
451+
self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
452+
) -> np.ndarray:
451453
"""predict classes for a given set of features
452454
453455
Args:
454-
features: dictionary of feature data to classify
456+
features: DataFrame of feature data to classify
455457
frame_indexes: frame indexes to classify (default all)
456458
457459
Returns:
@@ -477,11 +479,13 @@ def predict(self, features: dict, frame_indexes: np.ndarray | None = None) -> np
477479

478480
return result
479481

480-
def predict_proba(self, features: dict, frame_indexes: np.ndarray | None = None) -> np.ndarray:
482+
def predict_proba(
483+
self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
484+
) -> np.ndarray:
481485
"""predict probabilities for a given set of features.
482486
483487
Args:
484-
features: dictionary of feature data to classify
488+
features: DataFrame of feature data to classify
485489
frame_indexes: frame indexes to classify (default all)
486490
487491
Returns:

0 commit comments

Comments
 (0)