1010from ..networks .architectures .base_model import BaseModel
1111from ..utils .download_file import DownloadFile , DownloadFileInput
1212from ..utils .log import logger
13+ from ..utils .utils import mkdir
1314from .base import FileInfo , InferSession
1415
1516root_dir = Path (__file__ ).resolve ().parent .parent
1819
1920class TorchInferSession (InferSession ):
2021 def __init__ (self , cfg ) -> None :
21- self .logger = logger
22+ model_path = self ._init_model_path (cfg )
23+ arch_config = self ._load_arch_config (model_path )
2224
25+ self .predictor = self ._build_and_load_model (arch_config , model_path )
26+
27+ self ._setup_device (cfg )
28+
29+ self .predictor .eval ()
30+
31+ def _init_model_path (self , cfg ) -> Path :
2332 model_path = cfg .get ("model_path" , None )
2433 if model_path is None :
2534 model_info = self .get_model_url (
@@ -38,44 +47,69 @@ def __init__(self, cfg) -> None:
3847 file_url = default_model_url ,
3948 sha256 = model_info ["SHA256" ],
4049 save_path = model_path ,
41- logger = self . logger ,
50+ logger = logger ,
4251 )
4352 )
4453
45- self .logger .info (f"Using { model_path } " )
46- model_path = Path (model_path )
54+ logger .info (f"Using { model_path } " )
4755 self ._verify_model (model_path )
56+ return Path (model_path )
4857
58+ def _load_arch_config (self , model_path : Path ):
4959 all_arch_config = OmegaConf .load (DEFAULT_CFG_PATH )
60+
5061 file_name = model_path .stem
5162 if file_name not in all_arch_config :
5263 raise ValueError (f"architecture { file_name } is not in arch_config.yaml" )
5364
54- arch_config = all_arch_config .get (file_name )
55- self .predictor = BaseModel (arch_config )
56- self .predictor .load_state_dict (torch .load (model_path , map_location = "cpu" , weights_only = False ))
57- self .predictor .eval ()
58- self .use_gpu = False
59- self .use_npu = False
65+ return all_arch_config .get (file_name )
66+
67+ def _build_and_load_model (self , arch_config , model_path : Path ):
68+ model = BaseModel (arch_config )
69+ state_dict = torch .load (model_path , map_location = "cpu" , weights_only = False )
70+ model .load_state_dict (state_dict )
71+ return model
72+
73+ def _setup_device (self , cfg ):
74+ self .device , self .use_gpu , self .use_npu = self ._resolve_device_config (cfg )
75+
76+ if self .use_npu :
77+ self ._config_npu ()
78+
79+ self ._move_model_to_device ()
80+
81+ def _resolve_device_config (self , cfg ):
6082 if cfg .engine_cfg .use_cuda :
61- self .device = torch .device (f"cuda:{ cfg .engine_cfg .gpu_id } " )
62- self .predictor .to (self .device )
63- self .use_gpu = True
64- elif cfg .engine_cfg .use_npu :
65- try :
66- import torch_npu
67- options = {
68- # 设定算子编译的磁盘缓存模式,非必要每次重新编译
69- "ACL_OP_COMPILER_CACHE_MODE" : "enable" ,
70- # 指定缓存目录,确保路径已存在
71- "ACL_OP_COMPILER_CACHE_DIR" : "./kernel_meta" ,
72- }
73- torch_npu .npu .set_option (options )
74- except ImportError :
75- self .logger .warning ("torch_npu is not installed, options with ACL setting failed." )
76- self .device = torch .device (f"npu:{ cfg .engine_cfg .npu_id } " )
77- self .predictor .to (self .device )
78- self .use_npu = True
83+ return torch .device (f"cuda:{ cfg .engine_cfg .gpu_id } " ), True , False
84+
85+ if cfg .engine_cfg .use_npu :
86+ return torch .device (f"npu:{ cfg .engine_cfg .npu_id } " ), False , True
87+
88+ return torch .device ("cpu" ), False , False
89+
90+ def _config_npu (self ):
91+ try :
92+ import torch_npu
93+
94+ kernel_meta_dir = (root_dir / "kernel_meta" ).resolve ()
95+ mkdir (kernel_meta_dir )
96+
97+ options = {
98+ "ACL_OP_COMPILER_CACHE_MODE" : "enable" ,
99+ "ACL_OP_COMPILER_CACHE_DIR" : str (kernel_meta_dir ),
100+ }
101+ torch_npu .npu .set_option (options )
102+ except ImportError :
103+ logger .warning (
104+ "torch_npu is not installed, options with ACL setting failed. \n "
105+ "Please refer to https://github.com/Ascend/pytorch to see how to install."
106+ )
107+
108+ self .device = torch .device ("cpu" )
109+ self .use_npu = False
110+
111+ def _move_model_to_device (self ):
112+ self .predictor .to (self .device )
79113
80114 def __call__ (self , img : np .ndarray ):
81115 with torch .no_grad ():
0 commit comments