Commit 146bd5a

Merge pull request #242 from KumarLabJax/move-no-prediction
Move no prediction
2 parents: d5c874d + cdb15f5

5 files changed (+64, -66 lines)

src/jabs/classifier/classifier.py

Lines changed: 50 additions & 14 deletions
@@ -490,35 +490,71 @@ def sort_features_to_classify(self, features):
         features_sorted = features[classifier_columns]
         return features_sorted
 
-    def predict(self, features):
-        """predict classes for a given set of features"""
+    def predict(
+        self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
+    ) -> np.ndarray:
+        """predict classes for a given set of features
+
+        Args:
+            features: DataFrame of feature data to classify
+            frame_indexes: frame indexes to classify (default all)
+
+        Returns:
+            predicted class vector
+        """
         if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore", category=FutureWarning)
                 # XGBoost and CatBoost can handle NaN, just replace infinities
                 result = self._classifier.predict(
                     self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
                 )
-            return result
-        # Random Forest can't handle NAs & infs, so fill them with 0s
-        return self._classifier.predict(
-            self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
-        )
+        else:
+            # Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
+            result = self._classifier.predict(
+                self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
+            )
+
+        # Insert -1s into class prediction when no prediction is made
+        if frame_indexes is not None:
+            result_adjusted = np.full(result.shape, -1, dtype=np.int8)
+            result_adjusted[frame_indexes] = result[frame_indexes]
+            result = result_adjusted
+
+        return result
 
-    def predict_proba(self, features):
-        """predict probabilities for a given set of features"""
+    def predict_proba(
+        self, features: pd.DataFrame, frame_indexes: np.ndarray | None = None
+    ) -> np.ndarray:
+        """predict probabilities for a given set of features.
+
+        Args:
+            features: DataFrame of feature data to classify
+            frame_indexes: frame indexes to classify (default all)
+
+        Returns:
+            prediction probability matrix
+        """
         if self._classifier_type in (ClassifierType.XGBOOST, ClassifierType.CATBOOST):
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore", category=FutureWarning)
                 # XGBoost and CatBoost can handle NaN, just replace infinities
                 result = self._classifier.predict_proba(
                     self.sort_features_to_classify(features.replace([np.inf, -np.inf], np.nan))
                 )
-            return result
-        # Random Forest can't handle NAs & infs, so fill them with 0s
-        return self._classifier.predict_proba(
-            self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
-        )
+        else:
+            # Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
+            result = self._classifier.predict_proba(
+                self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
+            )
+
+        # Insert 0 probabilities when no prediction is made
+        if frame_indexes is not None:
+            result_adjusted = np.full(result.shape, 0, dtype=np.float32)
+            result_adjusted[frame_indexes] = result[frame_indexes]
+            result = result_adjusted
+
+        return result
 
     def save(self, path: Path):
         """save the classifier to a file

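The new frame_indexes handling in predict()/predict_proba() is a plain NumPy masking pattern: build an output array pre-filled with the "no prediction" sentinel, then copy classifier output only at the valid-pose frames. A minimal standalone sketch of the same idea (array values here are made up for illustration):

    import numpy as np

    # hypothetical classifier output: one predicted class per frame (5 frames)
    result = np.array([1, 0, 1, 1, 0])

    # frames where the animal had a valid pose; frames 2 and 4 did not
    frame_indexes = np.array([0, 1, 3])

    # start from "no prediction" (-1) and copy over only the valid-pose frames
    result_adjusted = np.full(result.shape, -1, dtype=np.int8)
    result_adjusted[frame_indexes] = result[frame_indexes]

    print(result_adjusted)  # [ 1  0 -1  1 -1]

For predict_proba() the same masking is applied with 0.0 as the fill value, so frames without a valid pose carry zero probability for every class.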
src/jabs/project/project.py

Lines changed: 3 additions & 28 deletions
@@ -441,7 +441,6 @@ def save_predictions(
         video_name: str,
         predictions: dict[int, np.ndarray],
         probabilities: dict[int, np.ndarray],
-        frame_indexes: dict[int, np.ndarray],
         behavior: str,
         classifier: object,
     ) -> None:
@@ -452,25 +451,8 @@ def save_predictions(
             video_name: name of the video these predictions correspond to.
             predictions: dict mapping identity to a 1D numpy array of predicted labels.
             probabilities: same structure as `predictions` but with floating-point values.
-            frame_indexes: dict mapping identity to 1D numpy array of absolute frame indices
-                listing the frames where the identity has a valid pose (i.e., frames with a meaningful prediction).
             behavior: string behavior name.
             classifier: Classifier object used to generate the predictions.
-
-        Note:
-            Currently, the classifier runs on every frame for every identity -- even when pose is invalid
-            and features are NaN. We copy values for *only* the frames with a valid pose. This is why we
-            index *both* the source and destination with `indexes` (an array with the absolute frame indices
-            of frames with a valid pose), e.g.:
-
-                prediction_labels[identity, indexes] = predictions[video][identity][indexes]
-                prediction_prob[identity, indexes] = probabilities[video][identity][indexes]
-
-            This leaves the output arrays with default values (-1 for labels, 0.0 for probabilities) for frames
-            without pose.
-
-            In the future, if the upstream caller were to provide compact arrays of length `len(indexes)`
-            instead of full-length arrays, the copy logic would need to drop the indexing on the source side.
         """
         # set up an output filename based on the video names
         file_base = Path(video_name).with_suffix("").name + ".h5"
@@ -482,17 +464,10 @@ def save_predictions(
         )
         prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)
 
-        # populate numpy arrays
+        # stack the numpy arrays
         for identity in predictions:
-            indexes = frame_indexes[identity]
-
-            # 'indexes' are absolute frame indices where this identity has a valid pose.
-            # predictions[identity] and probabilities[identity] are full-length arrays
-            # (len == num_frames); however, only elements at 'indexes' contain meaningful values.
-            # We index both source and destination with 'indexes' to copy only those valid-pose frames.
-            # If upstream ever provides compact arrays instead, drop the source-side indexing.
-            prediction_labels[identity, indexes] = predictions[identity][indexes]
-            prediction_prob[identity, indexes] = probabilities[identity][indexes]
+            prediction_labels[identity] = predictions[identity]
+            prediction_prob[identity] = probabilities[identity]
 
         # write to h5 file
         self._prediction_manager.write_predictions(

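Because predict()/predict_proba() now return full-length, already-padded arrays, save_predictions() no longer needs frame_indexes and can assign whole rows. A rough sketch of the shape contract this relies on (identity count, frame count, and values below are invented for illustration):

    import numpy as np

    num_identities, num_frames = 2, 5

    # per-identity arrays as the classifier now returns them: full video length,
    # with -1 / 0.0 already filled in for frames that had no valid pose
    predictions = {
        0: np.array([1, 0, -1, 1, -1], dtype=np.int8),
        1: np.array([-1, -1, 0, 1, 1], dtype=np.int8),
    }
    probabilities = {
        0: np.array([0.9, 0.8, 0.0, 0.7, 0.0], dtype=np.float32),
        1: np.array([0.0, 0.0, 0.6, 0.95, 0.9], dtype=np.float32),
    }

    prediction_labels = np.full((num_identities, num_frames), -1, dtype=np.int8)
    prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)

    # whole-row copy; the "no prediction" padding was already done upstream
    for identity in predictions:
        prediction_labels[identity] = predictions[identity]
        prediction_prob[identity] = probabilities[identity]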
src/jabs/scripts/classify.py

Lines changed: 5 additions & 9 deletions
@@ -137,22 +137,18 @@ def classify_pose(
             data = Classifier.combine_data(per_frame_features, window_features)
 
             if data.shape[0] > 0:
-                pred = classifier.predict(data)
-                pred_prob = classifier.predict_proba(data)
+                pred = classifier.predict(data, features["frame_indexes"])
+                pred_prob = classifier.predict_proba(data, features["frame_indexes"])
 
                 # Keep the probability for the predicted class only.
                 # The following code uses some
                 # numpy magic to use the pred array as column indexes
                 # for each row of the pred_prob array we just computed.
                 pred_prob = pred_prob[np.arange(len(pred_prob)), pred]
 
-                # Only copy out predictions where there was a valid pose
-                prediction_labels[curr_id, features["frame_indexes"]] = pred[
-                    features["frame_indexes"]
-                ]
-                prediction_prob[curr_id, features["frame_indexes"]] = pred_prob[
-                    features["frame_indexes"]
-                ]
+                # Copy results into results matrix
+                prediction_labels[curr_id] = pred
+                prediction_prob[curr_id] = pred_prob
             progress.update(task, advance=1)
 
     print(f"Writing predictions to {out_dir}")

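The "numpy magic" mentioned in the classify.py comments above is fancy indexing that pairs row i with column pred[i], keeping only the probability of the predicted class for each frame. A small sketch with made-up numbers:

    import numpy as np

    # hypothetical predict_proba output: 4 frames x 2 classes
    pred_prob = np.array([
        [0.2, 0.8],
        [0.6, 0.4],
        [0.1, 0.9],
        [0.7, 0.3],
    ])
    pred = np.array([1, 0, 1, 0])  # predicted class per frame

    # row i paired with column pred[i] -> probability of the predicted class
    pred_prob = pred_prob[np.arange(len(pred_prob)), pred]
    print(pred_prob)  # [0.8 0.6 0.9 0.7]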
src/jabs/ui/central_widget.py

Lines changed: 0 additions & 1 deletion
@@ -935,7 +935,6 @@ def _classify_thread_complete(self, output: dict) -> None:
         # display the new predictions
         self._predictions = output["predictions"]
         self._probabilities = output["probabilities"]
-        self._frame_indexes = output["frame_indexes"]
         self._cleanup_progress_dialog()
         self._cleanup_classify_thread()
         self.status_message.emit("Classification Complete", 3000)

src/jabs/ui/classification_thread.py

Lines changed: 6 additions & 14 deletions
@@ -84,7 +84,6 @@ def run(self) -> None:
         self._tasks_complete = 0
         current_video_predictions = {}
         current_video_probabilities = {}
-        current_video_frame_indexes = {}
 
         def check_termination_requested() -> None:
             if self._should_terminate:
@@ -104,7 +103,6 @@ def check_termination_requested() -> None:
                 # collect predictions, probabilities, and frame indexes for each identity in the video
                 predictions = {}
                 probabilities = {}
-                frame_indexes = {}
 
                 for identity in pose_est.identities:
                     check_termination_requested()
@@ -136,31 +134,27 @@ def check_termination_requested() -> None:
                     check_termination_requested()
                     if data.shape[0] > 0:
                         # make predictions
-                        # Note: this makes predictions for all frames in the video, even those without valid pose
-                        # We will later filter these out when saving the predictions to disk
-                        # consider changing this to only predict on frames with valid pose
-                        predictions[identity] = self._classifier.predict(data)
+                        predictions[identity] = self._classifier.predict(
+                            data, feature_values["frame_indexes"]
+                        )
 
                         # also get the probabilities
-                        prob = self._classifier.predict_proba(data)
+                        prob = self._classifier.predict_proba(
+                            data, feature_values["frame_indexes"]
+                        )
                         # Save the probability for the predicted class only.
                         # The following code uses some
                         # numpy magic to use the _predictions array as column indexes
                         # for each row of the 'prob' array we just computed.
                         probabilities[identity] = prob[np.arange(len(prob)), predictions[identity]]
-
-                        # save the indexes for the predicted frames
-                        frame_indexes[identity] = feature_values["frame_indexes"]
                     else:
                         predictions[identity] = np.array(0)
                         probabilities[identity] = np.array(0)
-                        frame_indexes[identity] = np.array(0)
 
                 if video == self._current_video:
                     # keep predictions for the video currently loaded in the video player
                     current_video_predictions = predictions.copy()
                     current_video_probabilities = probabilities.copy()
-                    current_video_frame_indexes = frame_indexes.copy()
 
                 # save predictions to disk
                 self.current_status.emit("Saving Predictions")
@@ -169,7 +163,6 @@ def check_termination_requested() -> None:
                     video,
                     predictions,
                     probabilities,
-                    frame_indexes,
                     self._behavior,
                     self._classifier,
                 )
@@ -183,7 +176,6 @@ def check_termination_requested() -> None:
                 {
                     "predictions": current_video_predictions,
                     "probabilities": current_video_probabilities,
-                    "frame_indexes": current_video_frame_indexes,
                 }
            )
        except Exception as e:

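Dropping frame_indexes from the thread output and the saved files does not lose the valid-frame information: since -1 labels and 0.0 probabilities now mark "no prediction" directly in the arrays, the indexes can be recovered from the data itself when needed. A minimal sketch (the array below is made up):

    import numpy as np

    # one identity's label row: -1 marks frames without a prediction
    labels = np.array([1, 0, -1, 1, -1], dtype=np.int8)

    # recover the frames that actually carry a prediction
    frame_indexes = np.flatnonzero(labels >= 0)
    print(frame_indexes)  # [0 1 3]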