Skip to content

Commit 44af4a9

Browse files
New version.
1 parent 208cf13 commit 44af4a9

File tree

10 files changed

+45
-51
lines changed

10 files changed

+45
-51
lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
Reduces GPU memory usage when running inference on top of sleap.
1+
- Adds support for DeepLabCut's csv format.
2+
- Unifies multithreading configuration to be through a single argument (thread_count).

diplomat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
33
"""
44

5-
__version__ = "0.3.7"
5+
__version__ = "0.3.8"
66
# Can be used by functions to determine if diplomat was invoked through it's CLI interface.
77
CLI_RUN = False
88

diplomat/predictors/fpe/frame_pass.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
self,
3131
width: int,
3232
height: int,
33-
multi_threading_allowed: bool,
33+
thread_count: int,
3434
config: Dict[str, Any]
3535
):
3636
# Set defaults to forward iteration...
@@ -41,7 +41,8 @@ def __init__(
4141

4242
self.__width = width
4343
self.__height = height
44-
self.__multi_threading_allowed = multi_threading_allowed
44+
self.__multi_threading_allowed = thread_count > 1
45+
self.__thread_count = thread_count
4546

4647
self._config = Config(config, self.get_config_options())
4748
self._frame_data = None
@@ -58,6 +59,10 @@ def height(self) -> int:
5859
def multi_threading_allowed(self) -> bool:
5960
return self.__multi_threading_allowed
6061

62+
@property
63+
def thread_count(self) -> int:
64+
return self.__thread_count
65+
6166
def _get_step_controls(self) -> Tuple[oint, oint, oint, oint]:
6267
return self._start, self._stop, self._step, self._prior_off
6368

diplomat/predictors/fpe/frame_pass_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def __call__(
2525
self,
2626
width: int,
2727
height: int,
28-
allow_multi_threading: bool = True
28+
thread_count: int = 1
2929
) -> FramePass:
3030
return self._clazz(
31-
width, height, allow_multi_threading, self._config
31+
width, height, thread_count, self._config
3232
)
3333

3434
@classmethod

diplomat/predictors/fpe/frame_passes/cluster_frames.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,15 @@ def run_pass(
5656
if(not in_place):
5757
raise ValueError("Clustering must be done in place!")
5858

59-
thread_count = os.cpu_count() if(self.config.thread_count is None) else self.config.thread_count
60-
61-
if(self.multi_threading_allowed and (thread_count > 0)):
59+
if(self.multi_threading_allowed and (self.thread_count > 1)):
6260
from diplomat.predictors.sfpe.segmented_frame_pass_engine import PoolWithProgress
6361

6462
self._frame_data = fb_data
6563
self._frame_data.allow_pickle = False
6664

6765
iter_range = RangeSlicer(self._frame_data.frames)[self._start:self._stop:self._step]
6866

69-
with PoolWithProgress(prog_bar, process_count=thread_count, sub_ticks=1) as pool:
67+
with PoolWithProgress(prog_bar, process_count=self.thread_count, sub_ticks=1) as pool:
7068
pool.fast_map(
7169
ClusterFrames._cluster_frames,
7270
lambda i: self._get_frame(iter_range[i]),
@@ -298,12 +296,6 @@ def get_config_options(cls) -> ConfigSpec:
298296
"max_throwaway_count": (
299297
10, float, "The maximum number of clusters to throw away before giving up on clustering a given frame."
300298
),
301-
"thread_count": (
302-
None,
303-
tc.Union(tc.Literal(None), tc.RangedInteger(0, np.inf)),
304-
"The number of threads to use during processing. If None, uses os.cpu_count(). "
305-
"If 0 disables multithreading."
306-
),
307299
"clustering_mode": (
308300
ClusteringMethod.COMPLETE.name,
309301
tc.Literal(*[n.name for n in ClusteringMethod]),

diplomat/predictors/fpe/frame_passes/fix_frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def compute_scores(
542542
to_index = lambda i: slice(i * cls.SCORES_PER_CHUNK, (i + 1) * cls.SCORES_PER_CHUNK)
543543
max_dist = np.sqrt(fb_data.metadata.width ** 2 + fb_data.metadata.height ** 2)
544544

545-
if(thread_count > 0):
545+
if(thread_count > 1):
546546
from ...sfpe.segmented_frame_pass_engine import PoolWithProgress
547547
with PoolWithProgress(prog_bar, process_count=thread_count, sub_ticks=1) as pool:
548548
pool.fast_map(

diplomat/predictors/fpe/frame_passes/mit_viterbi.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class NotAPool:
1717
T = TypeVar("T")
1818
E = TypeVar("E")
1919

20+
def __init__(self, *args, **kwargs):
21+
pass
22+
2023
def __enter__(self):
2124
return self
2225

@@ -332,7 +335,7 @@ def _run_backtrace(
332335
backtrace_priors = [None for __ in range(fb_data.num_bodyparts)]
333336
backtrace_current = [None for __ in range(fb_data.num_bodyparts)]
334337

335-
with pool_cls() as pool:
338+
with pool_cls(self.thread_count) as pool:
336339
exit_prob = fb_data.metadata.enter_trans_prob
337340
transition_function = ViterbiTransitionTable(self._gaussian_table, 1, exit_prob, 1 - exit_prob)
338341

@@ -422,7 +425,7 @@ def _run_backtrace(
422425
return fb_data
423426

424427
@staticmethod
425-
def _get_pool():
428+
def _get_pool(processes):
426429
# Check globals for a pool...
427430
"""This function sets up a multiprocessing pool for parallel processing,
428431
improving the efficiency of the algorithm by allowing it to process
@@ -434,11 +437,11 @@ def _get_pool():
434437
for ctx, args in [("forkserver", {}), ("spawn", {}), ("fork", {"maxtasksperchild": 1})]:
435438
try:
436439
ctx = get_context(ctx)
437-
return ctx.Pool(**args)
440+
return ctx.Pool(processes, **args)
438441
except ValueError:
439442
continue
440443

441-
return get_context().Pool()
444+
return get_context().Pool(processes)
442445

443446
def _run_forward(
444447
self,
@@ -475,7 +478,7 @@ def _run_forward(
475478
# We only use a pool if the body part group is high enough...
476479
pool_cls = self._get_pool if(self.multi_threading_allowed and (fb_data.num_bodyparts // meta.num_outputs) > 2) else NotAPool
477480

478-
with pool_cls() as pool:
481+
with pool_cls(self.thread_count) as pool:
479482
exit_prob = fb_data.metadata.enter_trans_prob
480483
transition_func = ViterbiTransitionTable(self._gaussian_table, 1, exit_prob, 1 - exit_prob)
481484

diplomat/predictors/sfpe/segmented_frame_pass_engine.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class PoolWithProgress:
271271
def __init__(
272272
self,
273273
progress_bar: ProgressBar,
274-
process_count: int = os.cpu_count(),
274+
process_count: int,
275275
refresh_rate_seconds: float = DEF_MAX_REFRESH_RATE,
276276
sub_ticks: int = DEF_SUB_TICKS,
277277
max_worker_reuse: Optional[int] = None,
@@ -764,8 +764,10 @@ def get_maximums(
764764
return poses
765765

766766
def _run_full_passes(self, progress_bar: Optional[ProgressBar]):
767+
thread_count = self._get_thread_count()
768+
767769
for (i, frame_pass_builder) in enumerate(self.FULL_PASSES):
768-
frame_pass = frame_pass_builder(self._width, self._height, True)
770+
frame_pass = frame_pass_builder(self._width, self._height, thread_count)
769771

770772
if(progress_bar is not None):
771773
progress_bar.message(
@@ -861,7 +863,7 @@ def _run_segment_pre_initialized(
861863
frame_pass_builders: List[FramePassBuilder],
862864
width: int,
863865
height: int,
864-
allow_multi_threading: bool,
866+
thread_count: int,
865867
fix_frame_idx: int,
866868
fix_frame_score: float,
867869
progress_bar: Optional[ProgressBar] = None,
@@ -871,7 +873,7 @@ def _run_segment_pre_initialized(
871873
frame_pass_builders,
872874
width,
873875
height,
874-
allow_multi_threading,
876+
thread_count,
875877
fix_frame_idx,
876878
fix_frame_score,
877879
progress_bar,
@@ -885,7 +887,7 @@ def _run_segment(
885887
frame_pass_builders: List[FramePassBuilder],
886888
width: int,
887889
height: int,
888-
allow_multi_threading: bool,
890+
thread_count: int,
889891
fix_frame_idx: int,
890892
fix_frame_score,
891893
progress_bar: Optional[ProgressBar] = None,
@@ -925,7 +927,7 @@ def _run_segment(
925927
progress_bar.inc_rerun_counter()
926928

927929
for (i, frame_pass_builder) in enumerate(frame_pass_builders):
928-
frame_pass = frame_pass_builder(width, height, allow_multi_threading)
930+
frame_pass = frame_pass_builder(width, height, thread_count)
929931

930932
sub_frame = frame_pass.run_pass(
931933
sub_frame,
@@ -950,7 +952,14 @@ def _get_segment(self, index: int):
950952
sub_frame.frames = self._frame_holder.frames[start:end]
951953
sub_frame.metadata = self._frame_holder.metadata
952954

953-
return (sub_frame, self.SEGMENTED_PASSES, self._width, self._height, False, fix_frame - start, segment_score)
955+
return (
956+
sub_frame,
957+
self.SEGMENTED_PASSES,
958+
self._width,
959+
self._height,
960+
0,
961+
fix_frame - start,
962+
segment_score)
954963

955964
def _set_segment(self, index: int, frame_data: ForwardBackwardData):
956965
start, end, fix_frame = self._segments[index]
@@ -1047,31 +1056,20 @@ def _run_segmented_passes(
10471056

10481057
self._frame_holder.allow_pickle = False
10491058

1050-
if(thread_count <= 0):
1059+
if(thread_count <= 1):
10511060
pass_count = (len(self.SEGMENTED_PASSES) + 1) * total_segments
10521061

1053-
passes_can_use_pool = any(b.clazz.UTILIZE_GLOBAL_POOL for b in self.SEGMENTED_PASSES)
1054-
allow_multithread = self.settings.allow_pass_multithreading
1055-
10561062
wrapper_bar = NestedProgressIndicator(
10571063
progress_bar,
10581064
total=pass_count,
10591065
ticks=int(self._frame_holder.num_frames / pass_count)
10601066
)
10611067
progress_bar.message("Running on Segments...")
10621068

1063-
if(passes_can_use_pool and allow_multithread):
1064-
with PoolWithProgress.get_optimal_ctx().Pool(processes=os.cpu_count()) as pool:
1065-
FramePass.GLOBAL_POOL = AntiCloseObject(pool)
1066-
for is_pre_init, segment_idx in self._iter_run_levels(segment_idxs, run_level_data):
1067-
for idx in segment_idx:
1068-
frm, segs, width, height, __, fix_frame_idx, fix_frame_score = self._get_segment(idx)
1069-
self._run_segment(frm, segs, width, height, self.settings.allow_pass_multithreading, fix_frame_idx, fix_frame_score, wrapper_bar, is_pre_init)
1070-
else:
1071-
for is_pre_init, segment_idx in self._iter_run_levels(segment_idxs, run_level_data):
1072-
for idx in segment_idx:
1073-
frm, segs, width, height, __, fix_frame_idx, fix_frame_score = self._get_segment(idx)
1074-
self._run_segment(frm, segs, width, height, self.settings.allow_pass_multithreading, fix_frame_idx, fix_frame_score, wrapper_bar, is_pre_init)
1069+
for is_pre_init, segment_idx in self._iter_run_levels(segment_idxs, run_level_data):
1070+
for idx in segment_idx:
1071+
frm, segs, width, height, __, fix_frame_idx, fix_frame_score = self._get_segment(idx)
1072+
self._run_segment(frm, segs, width, height, thread_count, fix_frame_idx, fix_frame_score, wrapper_bar, is_pre_init)
10751073

10761074
FramePass.GLOBAL_POOL = None
10771075
else:
@@ -1534,11 +1532,6 @@ def get_settings(cls) -> ConfigSpec:
15341532
"Defaults to None, which resolves to os.cpu_count() at runtime. "
15351533
"If set to 0, disables multithreading..."
15361534
),
1537-
"allow_pass_multithreading": (
1538-
True,
1539-
bool,
1540-
"Whether or not to allow frame passes to utilize multithreading. Defaults to True."
1541-
),
15421535
"segment_size": (
15431536
200,
15441537
type_casters.RangedInteger(10, np.inf),

diplomat/predictors/supervised_sfpe/supervised_segmented_frame_pass_engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,6 @@ def _partial_rerun(
504504
poses[s_i:e_i, :] = poses[s_i:e_i, seg_ord]
505505
old_poses.get_all()[:] = poses.reshape(old_poses.get_frame_count(), old_poses.get_bodypart_count() * 3)
506506

507-
508507
return (
509508
self.get_maximums(
510509
self._frame_holder,

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def _get_readme():
1919

2020
ONNX_TF_DEPS = [
2121
"h5py",
22+
"tables",
2223
"tensorflow",
2324
"tf2onnx>=1.16.1",
2425
"onnx",

0 commit comments

Comments
 (0)