Skip to content

Commit 5c44a2d

Browse files
Merge pull request #3 from TravisWheelerLab/develop
Version 0.0.5
2 parents b1e93c6 + 0c29a3c commit 5c44a2d

34 files changed

+659
-150
lines changed

diplomat/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
33
"""
44

5-
__version__ = "0.0.4"
5+
__version__ = "0.0.5"
66
# Can be used by functions to determine if diplomat was invoked through it's CLI interface.
77
CLI_RUN = False
88

@@ -61,12 +61,13 @@ def _load_frontends():
6161
if(hasattr(frontend, "__doc__")):
6262
mod.__doc__ = frontend.__doc__
6363

64-
for (name, func) in asdict(res).items():
64+
for (name, func) in res:
6565
if(not name.startswith("_")):
6666
func = replace_function_name_and_module(func, name, mod.__name__)
6767
setattr(mod, name, func)
6868
mod.__all__.append(name)
6969

7070
return frontends, loaded_funcs
7171

72+
7273
_FRONTENDS, _LOADED_FRONTENDS = _load_frontends()

diplomat/_cli_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_dynamic_cli_tree() -> dict:
4242

4343
for frontend_name, funcs in diplomat._LOADED_FRONTENDS.items():
4444
frontend_commands = {
45-
name: func for name, func in asdict(funcs).items() if(not name.startswith("_"))
45+
name: func for name, func in funcs if(not name.startswith("_"))
4646
}
4747

4848
doc_str = getattr(getattr(diplomat, frontend_name), "__doc__", None)

diplomat/core_ops.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import typing
88
from types import ModuleType
99
from diplomat.utils.tweak_ui import UIImportError
10+
from diplomat.frontends import DIPLOMATContract, DIPLOMATCommands
1011

1112

1213
class ArgumentError(CLIError):
@@ -47,14 +48,21 @@ def _get_casted_args(tc_func, extra_args, error_on_miss=True):
4748
return new_args
4849

4950

50-
def _find_frontend(config: os.PathLike, **kwargs: typing.Any) -> typing.Tuple[str, ModuleType]:
51+
def _find_frontend(
52+
contracts: Union[DIPLOMATContract, List[DIPLOMATContract]],
53+
config: os.PathLike,
54+
**kwargs: typing.Any
55+
) -> typing.Tuple[str, ModuleType]:
5156
from diplomat import _LOADED_FRONTENDS
5257

58+
contracts = [contracts] if(isinstance(contracts, DIPLOMATContract)) else contracts
59+
5360
for name, funcs in _LOADED_FRONTENDS.items():
54-
if(funcs._verifier(
61+
if(all(funcs.verify(
62+
contract=c,
5563
config=config,
5664
**kwargs
57-
)):
65+
) for c in contracts)):
5866
print(f"Frontend '{name}' selected.")
5967
return (name, funcs)
6068

@@ -178,6 +186,7 @@ def track(
178186
from diplomat import CLI_RUN
179187

180188
selected_frontend_name, selected_frontend = _find_frontend(
189+
contracts=[DIPLOMATCommands.analyze_videos, DIPLOMATCommands.analyze_videos],
181190
config=config,
182191
videos=videos,
183192
frame_stores=frame_stores,
@@ -324,7 +333,12 @@ def annotate(
324333
from diplomat import CLI_RUN
325334

326335
# Iterate the frontends, looking for one that actually matches our request...
327-
selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args)
336+
selected_frontend_name, selected_frontend = _find_frontend(
337+
contracts=DIPLOMATCommands.label_videos,
338+
config=config,
339+
videos=videos,
340+
**extra_args
341+
)
328342

329343
if(help_extra):
330344
_display_help(selected_frontend_name, "video labeling", "diplomat annotate", selected_frontend.label_videos, CLI_RUN)
@@ -350,18 +364,25 @@ def tweak(
350364
**extra_args
351365
):
352366
"""
353-
Make modifications to DIPLOMAT produced tracking results created for a video using a limited version supervised labeling UI. Allows for touching
354-
up and fixing any minor issues that may arise after tracking and saving results.
367+
Make modifications to DIPLOMAT produced tracking results created for a video using a limited version supervised
368+
labeling UI. Allows for touching up and fixing any minor issues that may arise after tracking and saving results.
355369
356-
:param config: The path to the configuration file for the project. The format of this argument will depend on the frontend.
370+
:param config: The path to the configuration file for the project. The format of this argument will depend on the
371+
frontend.
357372
:param videos: A single path or list of paths to video files to tweak the tracks of.
358-
:param help_extra: Boolean, if set to true print extra settings for the automatically selected frontend instead of showing the UI.
359-
:param extra_args: Any additional arguments (if the CLI, flags starting with '--') are passed to the automatically selected frontend.
360-
To see valid values, run tweak with extra_help flag set to true.
373+
:param help_extra: Boolean, if set to true print extra settings for the automatically selected frontend instead of
374+
showing the UI.
375+
:param extra_args: Any additional arguments (if the CLI, flags starting with '--') are passed to the automatically
376+
selected frontend. To see valid values, run tweak with extra_help flag set to true.
361377
"""
362378
from diplomat import CLI_RUN
363379

364-
selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args)
380+
selected_frontend_name, selected_frontend = _find_frontend(
381+
contracts=DIPLOMATCommands.tweak_videos,
382+
config=config,
383+
videos=videos,
384+
**extra_args
385+
)
365386

366387
if(help_extra):
367388
_display_help(selected_frontend_name, "label tweaking", "diplomat tweak", selected_frontend.tweak_videos, CLI_RUN)
@@ -404,7 +425,12 @@ def convert(
404425
"""
405426
from diplomat import CLI_RUN
406427

407-
selected_frontend_name, selected_frontend = _find_frontend(config=config, videos=videos, **extra_args)
428+
selected_frontend_name, selected_frontend = _find_frontend(
429+
contracts=DIPLOMATCommands.convert_results,
430+
config=config,
431+
videos=videos,
432+
**extra_args
433+
)
408434

409435
if(help_extra):
410436
_display_help(

diplomat/frontend_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def list_loaded_frontends():
3636
print("Description:")
3737
print(f"\t{frontend_docs[name]}")
3838
print("Supported Functions:")
39-
for k, v in asdict(funcs).items():
39+
for k, v in funcs:
4040
if(k.startswith("_")):
4141
continue
4242
print(f"\t{k}")

diplomat/frontends/__init__.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass, asdict
3+
from collections import OrderedDict
34
from diplomat.processing.type_casters import StrictCallable, PathLike, Union, List, Dict, Any, Optional, TypeCaster, NoneType
45
import typing
56

6-
77
class Select(Union):
88
def __eq__(self, other: TypeCaster):
99
if(isinstance(other, Union)):
@@ -62,29 +62,100 @@ def to_type_hint(self) -> typing.Type:
6262
)
6363

6464

65-
@dataclass(frozen=False)
66-
class DIPLOMATBaselineCommands:
65+
@dataclass(frozen=True)
66+
class DIPLOMATContract:
67+
"""
68+
Represents a 'contract'
69+
"""
70+
method_name: str
71+
method_type: StrictCallable
72+
73+
74+
class CommandManager(type):
75+
76+
__no_type_check__ = False
77+
def __new__(cls, *args, **kwargs):
78+
obj = super().__new__(cls, *args, **kwargs)
79+
80+
annotations = typing.get_type_hints(obj)
81+
82+
for name, annot in annotations.items():
83+
if(name in obj.__dict__):
84+
raise TypeError(f"Command annotation '{name}' has default value, which is not allowed.")
85+
86+
return obj
87+
88+
def __getattr__(self, item):
89+
annot = typing.get_type_hints(self)[item]
90+
return DIPLOMATContract(item, annot)
91+
92+
93+
def required(typecaster: TypeCaster) -> TypeCaster:
94+
typecaster._required = True
95+
return typecaster
96+
97+
98+
class DIPLOMATCommands(metaclass=CommandManager):
6799
"""
68100
The baseline set of functions each DIPLOMAT backend must implement. Backends can add additional commands
69-
by extending this base class...
101+
by passing the methods to this classes constructor.
70102
"""
71-
_verifier: VerifierFunction
103+
_verifier: required(VerifierFunction)
72104
analyze_videos: AnalyzeVideosFunction(NoneType)
73105
analyze_frames: AnalyzeFramesFunction(NoneType)
74106
label_videos: LabelVideosFunction(NoneType)
75107
tweak_videos: LabelVideosFunction(NoneType)
76108
convert_results: ConvertResultsFunction(NoneType)
77109

78-
def __post_init__(self):
110+
def __init__(self, **kwargs):
111+
missing = object()
112+
self._commands = OrderedDict()
113+
79114
annotations = typing.get_type_hints(type(self))
80115

81-
for name, value in asdict(self).items():
82-
annot = annotations.get(name, None)
116+
for name, annot in annotations.items():
117+
value = kwargs.get(name, missing)
83118

119+
if(value is missing):
120+
if(getattr(annot, "_required", False)):
121+
raise ValueError(f"Command '{name}' is required, but was not provided.")
122+
continue
84123
if(annot is None or (not isinstance(annot, TypeCaster))):
85124
raise TypeError("DIPLOMAT Command Struct can only contain typecaster types.")
86125

87-
setattr(self, name, annot(value))
126+
self._commands[name] = annot(value)
127+
128+
for name, value in kwargs.items():
129+
if(name not in annotations):
130+
self._commands[name] = value
131+
132+
def __iter__(self):
133+
return iter(self._commands.items())
134+
135+
def __getattr__(self, item: str):
136+
return self._commands.get(item)
137+
138+
def verify(self, contract: DIPLOMATContract, config: Union[List[PathLike], PathLike], **kwargs: Any) -> bool:
139+
"""
140+
Verify this backend can handle the provided command type, config file, and arguments.
141+
142+
:param contract: The contract for the command. Includes the name of the method and the type of the method,
143+
which will typically be a strict callable.
144+
:param config: The configuration file, checks if the backend can handle this configuration file.
145+
:param kwargs: Any additional arguments to pass to the backends verifier.
146+
147+
:return: A boolean, True if the backend can handle the provided command and arguments, otherwise False.
148+
"""
149+
if(contract.method_name in self._commands):
150+
func = self._commands[contract.method_name]
151+
try:
152+
contract.method_type(func)
153+
except Exception:
154+
return False
155+
156+
return self._verifier(config, **kwargs)
157+
158+
return False
88159

89160

90161
class DIPLOMATFrontend(ABC):
@@ -93,7 +164,7 @@ class DIPLOMATFrontend(ABC):
93164
"""
94165
@classmethod
95166
@abstractmethod
96-
def init(cls) -> typing.Optional[DIPLOMATBaselineCommands]:
167+
def init(cls) -> typing.Optional[DIPLOMATCommands]:
97168
"""
98169
Attempt to initialize the frontend, returning a list of api functions. If the backend can't initialize due to missing imports/requirements,
99170
this function should return None.

diplomat/frontends/csv/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Optional
2+
from diplomat.frontends import DIPLOMATFrontend, DIPLOMATCommands
3+
4+
5+
class DEEPLABCUTFrontend(DIPLOMATFrontend):
6+
"""
7+
The CSV frontend for DIPLOMAT. Contains functions for running some DIPLOMAT operations on csv trajectory files.
8+
Supports video creation, and tweak UI commands.
9+
"""
10+
@classmethod
11+
def init(cls) -> Optional[DIPLOMATCommands]:
12+
try:
13+
from diplomat.frontends.csv._verify_func import _verify
14+
from diplomat.frontends.csv.label_videos import label_videos
15+
from diplomat.frontends.csv.tweak_results import tweak_videos
16+
except ImportError:
17+
return None
18+
19+
return DIPLOMATCommands(
20+
_verifier=_verify,
21+
label_videos=label_videos,
22+
tweak_videos=tweak_videos
23+
)
24+
25+
@classmethod
26+
def get_package_name(cls) -> str:
27+
return "csv"
28+
29+
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from diplomat.processing.type_casters import typecaster_function, Union, List, PathLike
2+
from .csv_utils import _fix_paths, _header_check
3+
4+
5+
@typecaster_function
6+
def _verify(
7+
config: Union[List[PathLike], PathLike],
8+
**kwargs
9+
) -> bool:
10+
if("videos" not in kwargs):
11+
return False
12+
13+
try:
14+
config, videos = _fix_paths(config, kwargs["videos"])
15+
return all(_header_check(c) for c in config)
16+
except (IOError, ValueError):
17+
return False
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
3+
def _header_check(csv):
4+
with open(csv, "r") as csv_handle:
5+
first_lines = [csv_handle.readline().strip("\n").split(",") for i in range(3)]
6+
7+
header_cols = len(first_lines[0])
8+
9+
if(not all(header_cols == len(line) for line in first_lines)):
10+
return False
11+
12+
last_header_line = first_lines[-1]
13+
last_line_exp = ["x", "y", "likelihood"] * (len(last_header_line) // 3)
14+
15+
if(last_header_line != last_line_exp):
16+
return False
17+
18+
return True
19+
20+
21+
def _fix_paths(csvs, videos):
22+
csvs = csvs if(isinstance(csvs, (tuple, list))) else [csvs]
23+
videos = videos if(isinstance(videos, (tuple, list))) else [videos]
24+
25+
if(len(csvs) == 1):
26+
csvs = csvs * len(videos)
27+
if(len(videos) == 1):
28+
videos = videos * len(csvs)
29+
30+
if(len(videos) != len(csvs)):
31+
raise ValueError("Number of videos and csv files passes don't match!")
32+
33+
return csvs, videos

0 commit comments

Comments
 (0)