2222import os
2323from pathlib import Path
2424from time import perf_counter
25- from typing import Optional , Tuple , Union
25+ from typing import Literal , Optional , Tuple , Union
2626
2727import numpy as np
2828import supervision as sv
4141)
4242from focoos .processor .processor_manager import ProcessorManager
4343from focoos .utils .logger import get_logger
44- from focoos .utils .system import get_cpu_name , get_device_name
44+ from focoos .utils .system import get_cpu_name , get_device_name , get_device_type
4545from focoos .utils .vision import (
4646 annotate_frame ,
4747 image_loader ,
@@ -55,6 +55,7 @@ def __init__(
5555 self ,
5656 model_dir : Union [str , Path ],
5757 runtime_type : Optional [RuntimeType ] = None ,
58+ device : Literal ["cuda" , "cpu" , "auto" ] = "auto" ,
5859 ):
5960 """
6061 Initialize a LocalModel instance.
@@ -90,7 +91,12 @@ def __init__(
9091 # Determine runtime type and model format
9192 runtime_type = runtime_type or FOCOOS_CONFIG .runtime_type
9293 extension = ModelExtension .from_runtime_type (runtime_type )
93-
94+ if device == "auto" :
95+ self .device = get_device_type ()
96+ elif runtime_type == RuntimeType .ONNX_CPU :
97+ self .device = "cpu"
98+ else :
99+ self .device = device
94100 # Set model directory and path
95101 self .model_dir : Union [str , Path ] = model_dir
96102 self .model_path = os .path .join (model_dir , f"model.{ extension .value } " )
@@ -111,7 +117,7 @@ def __init__(
111117 model_config = ConfigManager .from_dict (self .model_info .model_family , self .model_info .config )
112118 self .processor = ProcessorManager .get_processor (
113119 self .model_info .model_family , model_config , self .model_info .im_size
114- )
120+ ). eval ()
115121 except Exception as e :
116122 logger .error (f"Error creating model config: { e } " )
117123 raise e
@@ -123,10 +129,11 @@ def __init__(
123129
124130 # Load runtime for inference
125131 self .runtime : BaseRuntime = load_runtime (
126- runtime_type ,
127- str (self .model_path ),
128- self .model_info ,
129- FOCOOS_CONFIG .warmup_iter ,
132+ runtime_type = runtime_type ,
133+ model_path = str (self .model_path ),
134+ model_info = self .model_info ,
135+ warmup_iter = FOCOOS_CONFIG .warmup_iter ,
136+ device = self .device ,
130137 )
131138
132139 def _read_model_info (self ) -> ModelInfo :
@@ -175,7 +182,7 @@ def infer(
175182 t0 = perf_counter ()
176183 im = image_loader (image )
177184 t1 = perf_counter ()
178- tensors , _ = self .processor .preprocess (inputs = im , device = "cuda" )
185+ tensors , _ = self .processor .preprocess (inputs = im , device = self . device )
179186 # logger.debug(f"Input image size: {im.shape}")
180187 t2 = perf_counter ()
181188
0 commit comments