Skip to content

Commit efd7b39

Browse files
Merge pull request #6 from TravisWheelerLab/update-pose-labelers
Draft: Branch for making changed to the pose labelers
2 parents dd406d1 + 20e32c0 commit efd7b39

File tree

8 files changed

+187
-10
lines changed

8 files changed

+187
-10
lines changed

diplomat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
33
"""
44

5-
__version__ = "0.1.0"
5+
__version__ = "0.1.1"
66
# Can be used by functions to determine if diplomat was invoked through it's CLI interface.
77
CLI_RUN = False
88

diplomat/predictors/fpe/frame_passes/mit_viterbi.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def __init__(
7272
self._enter_exit_prob = to_log_space(enter_exit_prob)
7373
self._enter_stay_prob = to_log_space(enter_stay_prob)
7474

75+
"""The ViterbiTransitionTable class is used to manage transition probabilities,
76+
including those modified by the dominance relationship and the "flat-topped" Gaussian distribution."""
77+
7578
@staticmethod
7679
def _is_enter_state(coords: Coords) -> bool:
7780
return len(coords[0]) == 1 and np.isneginf(coords[0][0])
@@ -132,6 +135,7 @@ def _init_gaussian_table(self, metadata: AttributeDict):
132135
else:
133136
self._scaled_std = (std if (std != "auto") else 1) / metadata.down_scaling
134137

138+
#flat topped gaussian
135139
self._flatten_std = None if (conf.gaussian_plateau is None) else self._scaled_std * conf.gaussian_plateau
136140
self._gaussian_table = norm(fpe_math.gaussian_table(
137141
self.height, self.width, self._scaled_std, conf.amplitude,
@@ -153,6 +157,27 @@ def _init_gaussian_table(self, metadata: AttributeDict):
153157
metadata.include_soft_domination = self.config.include_soft_domination
154158

155159
def _init_skeleton(self, data: ForwardBackwardData):
160+
"""If skeleton data is available, this function initializes the skeleton tables,
161+
which are used to enhance tracking by considering the structural
162+
relationships between different body parts.
163+
164+
The _skeleton_tables is a StorageGraph object that stores the relationship between different body parts
165+
as defined in the skeleton data from the metadata. Each entry in this table represents a connection
166+
between two body parts (nodes) and contains the statistical data (bin_val, freq, avg) related to that connection.
167+
This data is used to enhance tracking accuracy by considering the structural relationships between body parts.
168+
169+
Specifically, it stores:
170+
# - The names of the nodes (body parts) involved in the skeleton structure.
171+
# - A matrix for each pair of connected nodes, which is computed based on the skeleton formula. This matrix
172+
# represents the likelihood of transitioning from one body part to another, taking into account the average
173+
# distance and frequency of such transitions as observed in the training data.
174+
# - The configuration parameters used for calculating these matrices, which include adjustments for log space
175+
# calculations and other statistical considerations.
176+
# This structure is crucial for the Viterbi algorithm to accurately model the movement and relationships
177+
# between different parts of the body during tracking.
178+
179+
"""
180+
156181
if("skeleton" in data.metadata):
157182
meta = data.metadata
158183
self._skeleton_tables = StorageGraph(meta.skeleton.node_names())
@@ -188,6 +213,10 @@ def run_pass(
188213
in_place: bool = True,
189214
reset_bar: bool = True
190215
) -> ForwardBackwardData:
216+
"""
217+
This is the main function that orchestrates the forward and backward passes of the Viterbi algorithm.
218+
It initializes the necessary tables and states, then runs the forward pass to calculate probabilities,
219+
followed by a backtrace to determine the most probable paths."""
191220
with warnings.catch_warnings():
192221
warnings.filterwarnings("ignore")
193222
if("fixed_frame_index" not in fb_data.metadata):
@@ -343,6 +372,9 @@ def _run_backtrace(
343372
@staticmethod
344373
def _get_pool():
345374
# Check globals for a pool...
375+
"""This function sets up a multiprocessing pool for parallel processing,
376+
improving the efficiency of the algorithm by allowing it to process
377+
multiple parts of the frame or multiple frames simultaneously."""
346378
if(FramePass.GLOBAL_POOL is not None):
347379
return FramePass.GLOBAL_POOL
348380

@@ -431,6 +463,42 @@ def _compute_backtrace_step(
431463
soft_dom_weight: float = 0,
432464
skeleton_weight: float = 0
433465
) -> List[np.ndarray]:
466+
"""This method is responsible for computing the transition probabilities from the prior maximum locations
467+
(highest probability states) of all body parts in the prior frame to the current frame's states.
468+
It's where the algorithm determines the most probable path that leads to each pixel
469+
in the current frame based on the accumulated probabilities from previous frames.
470+
471+
Parameters
472+
prior: A list of lists containing tuples.
473+
Each tuple represents the probability and coordinates (x, y) of the prior maximum locations
474+
for all body parts in the prior frame.
475+
This data structure allows the method to consider multiple potential origins for each body part's current position.
476+
477+
current: A list of tuples containing the probability and coordinates (x, y) of the current frame's states
478+
This represents the possible current positions and their associated probabilities.
479+
480+
bp_idx: The index of the body part being processed. This is used to identify which part of the data corresponds to the current body part in multi-body part tracking scenarios.
481+
482+
metadata: The metadata from the ForwardBackwardData object.
483+
An AttributeDict containing metadata that might be necessary for the computation, such as configuration parameters or additional data needed for probability calculations.
484+
485+
transition_function: A function or callable object that calculates the transition probabilities between states. This is crucial for determining how likely it is to move from one state to another.
486+
487+
resist_transition_function: A function or callable object that calculates the resistance to transitioning between states.
488+
Similar to transition_function, but used for calculating resistive transitions, which might be part of handling interactions between different tracked objects or body parts.
489+
490+
skeleton_table: A StorageGraph object that stores the relationship between different body parts as defined in the skeleton data from the metadata.
491+
An optional parameter that, if provided, contains skeleton information that can be used to enhance the tracking by considering the structural relationships between different body parts.
492+
493+
soft_dom_weight: A float representing the weight of the soft domination factor.
494+
495+
skeleton_weight: A float representing the weight of the skeleton factor.
496+
497+
"""
498+
499+
# If skeleton information is available, the method first computes the influence of skeletal connections
500+
# on the transition probabilities.
501+
# This involves considering the structural relationships between body parts and adjusting probabilities accordingly.
434502
skel_res = cls._compute_from_skeleton(
435503
prior,
436504
current,
@@ -439,6 +507,11 @@ def _compute_backtrace_step(
439507
skeleton_table
440508
)
441509

510+
#The method then calculates the effect of soft domination,
511+
# which is a technique used to handle the dominance relationship between different paths.
512+
# This step adjusts the probabilities to favor more likely paths and suppress less likely ones,
513+
# based on the configured soft domination weight.
514+
442515
from_soft_dom = cls._compute_soft_domination(
443516
prior,
444517
current,
@@ -447,12 +520,22 @@ def _compute_backtrace_step(
447520
resist_transition_function,
448521
)
449522

523+
#The core of the method involves calculating the transition probabilities from the prior states to the current states.
524+
# This is done using the transition_function, which takes into account the distances between states and other factors
525+
# to determine how likely it is to transition from one state to another.
450526
trans_res = cls.log_viterbi_between(
451527
current,
452528
prior[bp_idx],
453529
transition_function
454530
)
455531

532+
#The calculated probabilities from the skeleton influence, soft domination,
533+
# and direct transitions are then combined to determine the overall probability of transitioning
534+
# to each current state from the prior states.
535+
# This involves weighting each component according to the configured weights and summing them up to get the final probabilities.
536+
537+
#Normalization: Finally, the probabilities are normalized to ensure they are within a valid range
538+
# and to facilitate comparison between different paths.
456539
return norm_together([
457540
t + s * skeleton_weight + d * soft_dom_weight for t, s, d in zip(trans_res, skel_res, from_soft_dom)
458541
])
@@ -527,6 +610,8 @@ def _compute_from_skeleton(
527610
merge_internal: Callable[[np.ndarray, int], np.ndarray] = np.max,
528611
merge_results: bool = True
529612
) -> Union[List[Tuple[int, List[NumericArray]]], List[NumericArray]]:
613+
614+
#TODO: Add docstring and notes in coda
530615
if(skeleton_table is None):
531616
return [0] * len(current_data) if(merge_results) else []
532617

@@ -597,6 +682,45 @@ def _compute_soft_domination(
597682
merge_internal: Callable[[np.ndarray, int], np.ndarray] = np.max,
598683
merge_results: bool = True
599684
) -> Union[List[Tuple[int, List[NumericArray]]], List[NumericArray]]:
685+
"""
686+
Computes the soft domination for a given body part across frames, considering prior and current data.
687+
688+
This method calculates the soft domination values by comparing the probabilities of a body part being in
689+
a certain state in the current frame against its probabilities in the prior frames. It uses a transition
690+
function to determine the likelihood of transitioning from each state in the prior frames to each state in
691+
the current frame. The results are merged using specified merging functions to find the most probable state
692+
transitions. This method can optionally merge the results across all body parts to find the overall most
693+
probable states.
694+
695+
Parameters:
696+
- prior: Union[List[ForwardBackwardFrame], List[List[Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]]]]
697+
The prior frame data or computed probabilities and coordinates for each body part.
698+
- current_data: List[Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]]
699+
The current frame data including probabilities and coordinates for each body part.
700+
- bp_idx: int
701+
The index of the body part being processed.
702+
- metadata: AttributeDict
703+
Metadata containing configuration and state information for the current processing.
704+
- transition_func: Optional[TransitionFunction]
705+
The function used to compute the transition probabilities between states.
706+
- merge_arrays: Callable[[Iterable[np.ndarray]], np.ndarray]
707+
A function to merge arrays of probabilities from different transitions.
708+
- merge_internal: Callable[[np.ndarray, int], np.ndarray]
709+
A function to merge probabilities within a single transition.
710+
- merge_results: bool
711+
A flag indicating whether to merge the results across all body parts.
712+
713+
Returns:
714+
- Union[List[Tuple[int, List[NumericArray]]], List[NumericArray]]:
715+
The computed soft domination values for the specified body part, either as a list of numeric arrays
716+
(if merge_results is False) or as a list of tuples containing the body part index and the list of
717+
numeric arrays (if merge_results is True).
718+
719+
This method is crucial for optimizing the Viterbi path selection by considering not only the most probable
720+
paths but also how these paths compare when considering potential transitions from prior states. It helps
721+
in refining the selection of paths that are not only probable in isolation but also in the context of the
722+
sequence of frames being analyzed.
723+
"""
600724
if(transition_func is None or metadata.num_outputs <= 1):
601725
return [0] * len(current_data) if(merge_results) else []
602726

@@ -669,6 +793,12 @@ def _compute_normal_frame(
669793
soft_dom_weight: float = 0,
670794
skeleton_weight: float = 0
671795
) -> List[ForwardBackwardFrame]:
796+
797+
"""processes a single frame in the context of tracking multiple body parts or individuals,
798+
calculating the probabilities of each body part being in each position based on prior information,
799+
current observations, and various transition models.
800+
It integrates several key concepts, including handling occlusions, leveraging skeleton information,
801+
and applying soft domination to refine the tracking process."""
672802
group_range = range(
673803
bp_group * metadata.num_outputs,
674804
(bp_group + 1) * metadata.num_outputs
@@ -814,6 +944,21 @@ def log_viterbi_between(
814944
merge_arrays: Callable[[Iterable[np.ndarray]], np.ndarray] = np.maximum.reduce,
815945
merge_internal: Callable[[np.ndarray, int], np.ndarray] = np.nanmax
816946
) -> List[np.ndarray]:
947+
"""
948+
This method calculates the transition probabilities between the prior and current data points for each body part.
949+
It utilizes a transition function to compute the probabilities of moving from each prior state to each current state.
950+
The method then merges these probabilities across all body parts to determine the most likely transitions.
951+
952+
Parameters:
953+
- current_data: A sequence of tuples containing the current probabilities and coordinates for each body part.
954+
- prior_data: A sequence of tuples containing the prior probabilities and coordinates for each body part.
955+
- transition_function: A callable that computes the transition probabilities between prior and current states.
956+
- merge_arrays: A callable that merges arrays of probabilities across all body parts.
957+
- merge_internal: A callable that merges probabilities within each body part.
958+
959+
Returns:
960+
A list of numpy arrays representing the merged transition probabilities for each body part.
961+
"""
817962
return [
818963
merge_arrays([
819964
merge_internal(
@@ -851,6 +996,11 @@ def generate_occluded(
851996
@classmethod
852997
def get_config_options(cls) -> ConfigSpec:
853998
# Class to enforce that probabilities are between 0 and 1....
999+
"""This function returns a dictionary of configuration options that can be adjusted to
1000+
customize the behavior of the algorithm.
1001+
These options include parameters for the Gaussian distribution,
1002+
probabilities for obscured and edge states,
1003+
and weights for the dominance relationship and skeleton data."""
8541004
return {
8551005
"standard_deviation": (
8561006
"auto", tc.Union(float, tc.Literal("auto")),

diplomat/predictors/supervised_fpe/labelers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def predict_location(
5757
frame = self._frame_engine.frame_data.frames[frame_idx][bp_idx]
5858

5959
if(x is None):
60+
#should we be returning this prob value or the probability value?
6061
x, y, prob = self._frame_engine.scmap_to_video_coord(
6162
*self._frame_engine.get_maximum_with_defaults(frame),
6263
meta.down_scaling
@@ -206,7 +207,7 @@ def predict_location(
206207
bp_idx: int,
207208
x: float,
208209
y: float,
209-
probability: float
210+
probability: float,
210211
) -> Tuple[Any, Tuple[float, float, float]]:
211212
info = self._settings.get_values()
212213
user_amp = info.user_input_strength / 1000
@@ -293,7 +294,10 @@ def pose_change(self, new_state: Any) -> Any:
293294
)
294295
new_data.pack(*[np.array([item]) for item in [y, x, prob, off_x, off_y]])
295296
else:
296-
new_data = suggested_frame.src_data
297+
y, x, prob, x_offset, y_offset = suggested_frame.src_data.unpack()
298+
max_prob_idx = np.argmax(prob)
299+
new_data = SparseTrackingData()
300+
new_data.pack(*[np.array([item]) for item in [y[max_prob_idx], x[max_prob_idx], 1, x_offset[max_prob_idx], y_offset[max_prob_idx]]])
297301

298302
new_frame = ForwardBackwardFrame()
299303
new_frame.orig_data = new_data
@@ -481,7 +485,10 @@ def pose_change(self, new_state: Any) -> Any:
481485
)
482486
new_data.pack(*[np.array([item]) for item in [y, x, prob, off_x, off_y]])
483487
else:
484-
new_data = suggested_frame.src_data
488+
y, x, prob, x_offset, y_offset = suggested_frame.src_data.unpack()
489+
max_prob_idx = np.argmax(prob)
490+
new_data = SparseTrackingData()
491+
new_data.pack(*[np.array([item]) for item in [y[max_prob_idx], x[max_prob_idx], 1, x_offset[max_prob_idx], y_offset[max_prob_idx]]])
485492

486493
new_frame = ForwardBackwardFrame()
487494
new_frame.orig_data = new_data

diplomat/predictors/supervised_fpe/supervised_frame_pass_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def on_end(self, progress_bar: ProgressBar) -> Union[None, Pose]:
6767
self._get_names(),
6868
self.video_metadata,
6969
self._get_crop_box(),
70-
[Approximate(self), ApproximateSourceOnly(self), Point(self), NearestPeakInSource(self)],
70+
[Approximate(self), Point(self), NearestPeakInSource(self), ApproximateSourceOnly(self)],
7171
[EntropyOfTransitions(self), MaximumJumpInStandardDeviations(self)],
7272
None,
7373
list(range(1, self.num_outputs + 1)) * (self._num_total_bp // self.num_outputs)

diplomat/predictors/supervised_sfpe/supervised_segmented_frame_pass_engine.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,15 @@ def _partial_rerun(
516516
old_poses: Pose,
517517
progress_bar: ProgressBar
518518
) -> Tuple[Pose, Iterable[int]]:
519+
520+
#TODO : delete below lines, not doing as expected
521+
522+
# # For each changed frame and each body part, take the maximum probability coordinates and set them to one
523+
# for (frame_idx, bp_idx), frame in changed_frames.items():
524+
# max_prob_coord = np.unravel_index(frame.frame_probs.argmax(), frame.frame_probs.shape)
525+
# new_frame_probs = np.zeros_like(frame.frame_probs) #copy because this is read only
526+
# new_frame_probs[max_prob_coord] = 1
527+
# frame.frame_probs = new_frame_probs
519528
# Determine what segments have been manipulated...
520529
segment_indexes = sorted({np.searchsorted(self._segments[:, 1], f_i, "right") for f_i, b_i in changed_frames})
521530

@@ -531,7 +540,8 @@ def _partial_rerun(
531540
for (s_i, e_i, f_i), seg_ord in zip(self._segments, self._segment_bp_order):
532541
poses[s_i:e_i, :] = poses[s_i:e_i, seg_ord]
533542
old_poses.get_all()[:] = poses.reshape(old_poses.get_frame_count(), old_poses.get_bodypart_count() * 3)
534-
543+
544+
535545
return (
536546
self.get_maximums(
537547
self._frame_holder,
@@ -672,7 +682,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
672682
self._get_names(),
673683
self.video_metadata,
674684
self._get_crop_box(),
675-
[Approximate(self), ApproximateSourceOnly(self), Point(self), NearestPeakInSource(self)],
685+
[Approximate(self), Point(self), NearestPeakInSource(self), ApproximateSourceOnly(self)],
676686
[EntropyOfTransitions(self), MaximumJumpInStandardDeviations(self)],
677687
None,
678688
list(range(1, self.num_outputs + 1)) * (self._num_total_bp // self.num_outputs),

diplomat/wx_gui/fpe_editor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def __init__(
288288
bp_names=names,
289289
labeling_modes=labeling_modes,
290290
group_list=part_groups,
291+
# skeleton_info = self.skeleton_info
291292
**ps
292293
)
293294
self.video_controls = VideoController(self._sub_panel, video_player=self.video_player.video_viewer)
@@ -336,6 +337,7 @@ def __init__(
336337

337338
self.video_controls.Bind(PointViewNEdit.EVT_FRAME_CHANGE, self._on_frame_chg)
338339

340+
339341
def _on_close_caller(self, event: wx.CloseEvent):
340342
self._on_close(event, self._was_save_button_flag)
341343
self._was_save_button_flag = False

0 commit comments

Comments
 (0)