Skip to content

Commit bc036a7

Browse files
committed
WIP refactor benchmarking: Introduce Engine
1 parent f2c6b11 commit bc036a7

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

dlclive/factory.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def build_runner(
1616
1717
Parameters
1818
----------
19-
model_type: str, optional
19+
model_type: str
2020
Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
2121
TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
2222
model_path: str, Path
@@ -33,13 +33,13 @@ def build_runner(
3333
-------
3434
3535
"""
36-
if model_type.lower() == "pytorch":
36+
if Engine.from_model_type(model_type) == Engine.PYTORCH:
3737
from dlclive.pose_estimation_pytorch.runner import PyTorchRunner
3838

3939
valid = {"device", "precision", "single_animal", "dynamic", "top_down_config"}
4040
return PyTorchRunner(model_path, **filter_keys(valid, kwargs))
4141

42-
elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"):
42+
elif Engine.from_model_type(model_type) == Engine.TENSORFLOW:
4343
from dlclive.pose_estimation_tensorflow.runner import TensorFlowRunner
4444

4545
if model_type.lower() == "tensorflow":
@@ -54,3 +54,19 @@ def build_runner(
5454
def filter_keys(valid: set[str], kwargs: dict) -> dict:
5555
"""Filters the keys in kwargs, only keeping those in valid."""
5656
return {k: v for k, v in kwargs.items() if k in valid}
57+
58+
59+
from enum import Enum
60+
61+
class Engine(Enum):
62+
TENSORFLOW = "tensorflow"
63+
PYTORCH = "pytorch"
64+
65+
@classmethod
66+
def from_model_type(cls, model_type: str) -> "Engine":
67+
if model_type.lower() == "pytorch":
68+
return cls.PYTORCH
69+
elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"):
70+
return cls.TENSORFLOW
71+
else:
72+
raise ValueError(f"Unknown model type: {model_type}")

0 commit comments

Comments
 (0)