
Commit d961704

Merge pull request #149 from FocoosAI/fix/onnx-cpu-infer
fix(InferModel): onnx cpu inference by adding an "auto" device parameter
2 parents f69c6ee + ebcd463 commit d961704

File tree

8 files changed: +61 −27 lines

- docs/cli.md
- focoos/infer/infer_model.py
- focoos/infer/runtimes/load_runtime.py
- focoos/infer/runtimes/torchscript.py
- focoos/models/focoos_model.py
- focoos/utils/system.py
- tests/test_runtime.py
- uv.lock

docs/cli.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -251,7 +251,7 @@ The interface will automatically open in your default web browser, typically at
 | `--source` | Input source (predict only) | **Required** | `image.jpg` |
 | `--im-size` | Input image size | 640 | Any positive integer |
 | `--batch-size` | Batch size | 16 | Powers of 2 recommended |
-| `--device` | Compute device | `cuda` | `cuda`, `cpu`, `mps` |
+| `--device` | Compute device | `cuda` | `cuda`, `cpu` |
 | `--workers` | Data loading workers | 4 | 0-16 recommended |
 | `--output-dir` | Output directory | Auto-generated | Any valid path |
```

focoos/infer/infer_model.py

Lines changed: 16 additions & 9 deletions

```diff
@@ -22,7 +22,7 @@
 import os
 from pathlib import Path
 from time import perf_counter
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
 
 import numpy as np
 import supervision as sv
@@ -41,7 +41,7 @@
 )
 from focoos.processor.processor_manager import ProcessorManager
 from focoos.utils.logger import get_logger
-from focoos.utils.system import get_cpu_name, get_device_name
+from focoos.utils.system import get_cpu_name, get_device_name, get_device_type
 from focoos.utils.vision import (
     annotate_frame,
     image_loader,
@@ -55,6 +55,7 @@ def __init__(
         self,
         model_dir: Union[str, Path],
         runtime_type: Optional[RuntimeType] = None,
+        device: Literal["cuda", "cpu", "auto"] = "auto",
     ):
         """
         Initialize a LocalModel instance.
@@ -90,7 +91,12 @@ def __init__(
         # Determine runtime type and model format
         runtime_type = runtime_type or FOCOOS_CONFIG.runtime_type
         extension = ModelExtension.from_runtime_type(runtime_type)
-
+        if device == "auto":
+            self.device = get_device_type()
+        elif runtime_type == RuntimeType.ONNX_CPU:
+            self.device = "cpu"
+        else:
+            self.device = device
         # Set model directory and path
         self.model_dir: Union[str, Path] = model_dir
         self.model_path = os.path.join(model_dir, f"model.{extension.value}")
@@ -111,7 +117,7 @@ def __init__(
             model_config = ConfigManager.from_dict(self.model_info.model_family, self.model_info.config)
             self.processor = ProcessorManager.get_processor(
                 self.model_info.model_family, model_config, self.model_info.im_size
-            )
+            ).eval()
         except Exception as e:
             logger.error(f"Error creating model config: {e}")
             raise e
@@ -123,10 +129,11 @@ def __init__(
 
         # Load runtime for inference
         self.runtime: BaseRuntime = load_runtime(
-            runtime_type,
-            str(self.model_path),
-            self.model_info,
-            FOCOOS_CONFIG.warmup_iter,
+            runtime_type=runtime_type,
+            model_path=str(self.model_path),
+            model_info=self.model_info,
+            warmup_iter=FOCOOS_CONFIG.warmup_iter,
+            device=self.device,
         )
 
     def _read_model_info(self) -> ModelInfo:
@@ -175,7 +182,7 @@ def infer(
         t0 = perf_counter()
         im = image_loader(image)
         t1 = perf_counter()
-        tensors, _ = self.processor.preprocess(inputs=im, device="cuda")
+        tensors, _ = self.processor.preprocess(inputs=im, device=self.device)
         # logger.debug(f"Input image size: {im.shape}")
         t2 = perf_counter()
```
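Taken together, these changes resolve the device once in the constructor and reuse it for both preprocessing and the runtime. A minimal usage sketch (assuming `InferModel` is the class defined in `focoos/infer/infer_model.py`, as the PR title suggests; `./exported_model` is a placeholder directory holding the exported model and its model info):

```python
from focoos.infer.infer_model import InferModel  # class name assumed from the PR title
from focoos.ports import RuntimeType

# device="auto" (the new default) resolves via get_device_type():
# "cuda" when torch.cuda.is_available(), "cpu" otherwise.
model = InferModel(model_dir="./exported_model")

# With an explicit device and the ONNX CPU runtime, the constructor
# forces self.device = "cpu", so preprocessing no longer targets the
# previously hard-coded "cuda".
cpu_model = InferModel(
    model_dir="./exported_model",
    runtime_type=RuntimeType.ONNX_CPU,
    device="cpu",
)
```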

focoos/infer/runtimes/load_runtime.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -1,3 +1,5 @@
+from typing import Literal
+
 from focoos.infer.runtimes.base import BaseRuntime
 from focoos.ports import ModelInfo, OnnxRuntimeOpts, RuntimeType, TorchscriptRuntimeOpts
 from focoos.utils.logger import get_logger
@@ -25,6 +27,7 @@ def load_runtime(
     model_path: str,
     model_info: ModelInfo,
     warmup_iter: int = 50,
+    device: Literal["cuda", "cpu", "auto"] = "auto",
 ) -> BaseRuntime:
     """
     Creates and returns a runtime instance based on the specified runtime type.
@@ -57,7 +60,7 @@ def load_runtime(
         from focoos.infer.runtimes.torchscript import TorchscriptRuntime
 
         opts = TorchscriptRuntimeOpts(warmup_iter=warmup_iter)
-        return TorchscriptRuntime(model_path=model_path, opts=opts, model_info=model_info)
+        return TorchscriptRuntime(model_path=model_path, opts=opts, model_info=model_info, device=device)
     else:
         if not ORT_AVAILABLE:
             logger.error(
```
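For reference, a hedged sketch of the updated call, mirroring the keyword-argument style `InferModel` now uses (`model_info` is a placeholder assumed to be loaded elsewhere):

```python
from focoos.infer.runtimes.load_runtime import load_runtime
from focoos.ports import ModelInfo, RuntimeType

# Placeholder: in practice this comes from the exported model's
# metadata (see InferModel._read_model_info above).
model_info: ModelInfo = ...  # type: ignore[assignment]

runtime = load_runtime(
    runtime_type=RuntimeType.TORCHSCRIPT_32,
    model_path="./exported_model/model.pt",
    model_info=model_info,
    warmup_iter=50,
    device="auto",  # in this diff, only the TorchScript branch forwards device
)
```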

focoos/infer/runtimes/torchscript.py

Lines changed: 8 additions & 4 deletions

```diff
@@ -1,13 +1,13 @@
 from time import perf_counter
-from typing import Tuple, Union
+from typing import Literal, Tuple, Union
 
 import numpy as np
 import torch
 
 from focoos.infer.runtimes.base import BaseRuntime
 from focoos.ports import LatencyMetrics, ModelInfo, Task, TorchscriptRuntimeOpts
 from focoos.utils.logger import get_logger
-from focoos.utils.system import get_cpu_name, get_device_name
+from focoos.utils.system import get_cpu_name, get_device_name, get_device_type
 
 logger = get_logger("TorchscriptRuntime")
 
@@ -32,8 +32,12 @@ def __init__(
         model_path: str,
         opts: TorchscriptRuntimeOpts,
         model_info: ModelInfo,
+        device: Literal["cuda", "cpu", "auto"] = "auto",
     ):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if device == "auto":
+            self.device = torch.device(get_device_type())
+        else:
+            self.device = torch.device(device)
         logger.info(f"🔧 Device: {self.device}")
         self.opts = opts
         self.model_info = model_info
@@ -49,7 +53,7 @@ def __init__(
         )
         logger.info(f"⏱️ Warming up model {self.model_info.name} on {self.device}, size: {size}x{size}..")
         with torch.no_grad():
-            np_image = torch.rand(1, 3, size, size, device=self.device)
+            np_image = torch.rand(1, 3, size, size).to(self.device)
             for _ in range(self.opts.warmup_iter):
                 self.model(np_image)
```
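The constructor's device resolution reduces to a single expression; a standalone equivalent for illustration (the `resolve_device` helper is hypothetical, not part of this PR):

```python
import torch

from focoos.utils.system import get_device_type


def resolve_device(device: str = "auto") -> torch.device:
    # Mirrors TorchscriptRuntime.__init__: "auto" defers to
    # get_device_type(); any explicit value is used verbatim.
    return torch.device(get_device_type() if device == "auto" else device)
```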

focoos/models/focoos_model.py

Lines changed: 8 additions & 5 deletions

```diff
@@ -32,7 +32,7 @@
 from focoos.utils.distributed.dist import launch
 from focoos.utils.env import TORCH_VERSION
 from focoos.utils.logger import get_logger
-from focoos.utils.system import get_cpu_name, get_device_name, get_focoos_version, get_system_info
+from focoos.utils.system import get_cpu_name, get_device_name, get_device_type, get_focoos_version, get_system_info
 from focoos.utils.vision import annotate_frame, image_loader
 
 logger = get_logger("FocoosModel")
@@ -393,7 +393,7 @@ def export(
         runtime_type: RuntimeType = RuntimeType.TORCHSCRIPT_32,
         onnx_opset: int = 17,
         out_dir: Optional[str] = None,
-        device: Literal["cuda", "cpu"] = "cuda",
+        device: Literal["cuda", "cpu", "auto"] = "auto",
         overwrite: bool = True,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
     ) -> InferModel:
@@ -416,9 +416,12 @@ def export(
         Raises:
             ValueError: If unsupported PyTorch version or export format.
         """
-        if device == "cuda" and not torch.cuda.is_available():
-            device = "cpu"
-            logger.warning("CUDA is not available. Using CPU for export.")
+        if device == "auto":
+            device = get_device_type()  # type: ignore
+        else:
+            device = device
+
+        logger.info(f"🔧 Export Device: {device}")
         if out_dir is None:
             out_dir = os.path.join(MODELS_DIR, self.model_info.ref or self.model_info.name)
 
```
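Export now follows the same `"auto"` convention instead of silently falling back from CUDA to CPU. A hedged usage sketch, assuming `model` is a `FocoosModel` instance obtained elsewhere:

```python
from focoos.ports import RuntimeType

# device="auto" (the new default) resolves via get_device_type() and is
# logged as "🔧 Export Device: ..." before the export begins.
infer_model = model.export(
    runtime_type=RuntimeType.TORCHSCRIPT_32,
    device="auto",
)
```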

focoos/utils/system.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -7,7 +7,9 @@
 import time
 import zipfile
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Literal, Optional, Union
+
+import torch
 
 from focoos.ports import GPUInfo
 from focoos.utils.distributed import comm
@@ -413,3 +415,10 @@ def get_device_name() -> str:
     else:
         cpu_name = get_cpu_name()
     return cpu_name if cpu_name is not None else "CPU"
+
+
+def get_device_type() -> Literal["cuda", "cpu"]:
+    if torch.cuda.is_available():
+        return "cuda"
+    else:
+        return "cpu"
```
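`get_device_type()` is the single helper the other call sites now rely on; a quick check of its behavior:

```python
from focoos.utils.system import get_device_type

# "cuda" if torch.cuda.is_available() else "cpu"; note "mps" is not an
# option, consistent with the docs/cli.md change above.
print(get_device_type())  # e.g. "cpu" on a machine without CUDA
```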

tests/test_runtime.py

Lines changed: 13 additions & 5 deletions

```diff
@@ -147,11 +147,19 @@ def test_load_runtime(mocker: MockerFixture, tmp_path, runtime_type, expected_op
 
     # assertions
     assert runtime is not None
-    mock_runtime_class.assert_called_once_with(
-        model_path,
-        expected_opts,
-        mock_model_metadata,
-    )
+    if runtime_type == RuntimeType.TORCHSCRIPT_32:
+        mock_runtime_class.assert_called_once_with(
+            model_path=model_path,
+            opts=expected_opts,
+            model_info=mock_model_metadata,
+            device="auto",
+        )
+    else:
+        mock_runtime_class.assert_called_once_with(
+            model_path,
+            expected_opts,
+            mock_model_metadata,
+        )
 
 
 def test_load_unavailable_runtime(mocker: MockerFixture):
```

uv.lock

Lines changed: 1 addition & 1 deletion (generated file; diff not rendered)
