@@ -16,7 +16,7 @@ def build_runner(
16
16
17
17
Parameters
18
18
----------
19
- model_type: str, optional
19
+ model_type: str
20
20
Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
21
21
TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
22
22
model_path: str, Path
@@ -33,13 +33,13 @@ def build_runner(
33
33
-------
34
34
35
35
"""
36
- if model_type . lower ( ) == "pytorch" :
36
+ if Engine . from_model_type ( model_type ) == Engine . PYTORCH :
37
37
from dlclive .pose_estimation_pytorch .runner import PyTorchRunner
38
38
39
39
valid = {"device" , "precision" , "single_animal" , "dynamic" , "top_down_config" }
40
40
return PyTorchRunner (model_path , ** filter_keys (valid , kwargs ))
41
41
42
- elif model_type . lower () in ( "tensorflow" , "base" , "tensorrt" , "lite" ) :
42
+ elif Engine . from_model_type ( model_type ) == Engine . TENSORFLOW :
43
43
from dlclive .pose_estimation_tensorflow .runner import TensorFlowRunner
44
44
45
45
if model_type .lower () == "tensorflow" :
@@ -54,3 +54,19 @@ def build_runner(
54
54
def filter_keys (valid : set [str ], kwargs : dict ) -> dict :
55
55
"""Filters the keys in kwargs, only keeping those in valid."""
56
56
return {k : v for k , v in kwargs .items () if k in valid }
57
+
58
+
59
+ from enum import Enum
60
+
61
+ class Engine (Enum ):
62
+ TENSORFLOW = "tensorflow"
63
+ PYTORCH = "pytorch"
64
+
65
+ @classmethod
66
+ def from_model_type (cls , model_type : str ) -> "Engine" :
67
+ if model_type .lower () == "pytorch" :
68
+ return cls .PYTORCH
69
+ elif model_type .lower () in ("tensorflow" , "base" , "tensorrt" , "lite" ):
70
+ return cls .TENSORFLOW
71
+ else :
72
+ raise ValueError (f"Unknown model type: { model_type } " )
0 commit comments