Skip to content

Commit 7671ca5

Browse files
committed
add onnx utils
1 parent 96b4ca6 commit 7671ca5

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Utilities for loading ONNX surrogate model files."""
6+
7+
from __future__ import annotations
8+
9+
import sys
10+
from pathlib import Path
11+
from typing import Optional
12+
13+
# Data package name for ONNX surrogate models
14+
_ONNX_PACKAGE = "surfaces_onnx_files"
15+
16+
# Local models directory (for newly trained models)
17+
_LOCAL_MODELS_DIR = Path(__file__).parent / "models"
18+
19+
20+
def _get_trained_model_path(filename: str) -> Optional[Path]:
21+
"""Check if file exists in the local models directory (for newly trained models)."""
22+
local_path = _LOCAL_MODELS_DIR / filename
23+
if local_path.exists():
24+
return local_path
25+
return None
26+
27+
28+
def _get_local_onnx_path(filename: str) -> Optional[Path]:
29+
"""Check if file exists in the local data package directory (for development)."""
30+
# This file is at src/surfaces/_surrogates/_onnx_utils.py
31+
# Need 4 parents to get to repo root
32+
repo_root = Path(__file__).parent.parent.parent.parent
33+
local_path = (
34+
repo_root
35+
/ "data-packages"
36+
/ "surfaces-onnx-files"
37+
/ "src"
38+
/ "surfaces_onnx_files"
39+
/ filename
40+
)
41+
if local_path.exists():
42+
return local_path
43+
return None
44+
45+
46+
def _get_installed_onnx_path(filename: str) -> Optional[Path]:
47+
"""Get file path from the installed surfaces-onnx-files package."""
48+
try:
49+
if sys.version_info >= (3, 9):
50+
from importlib.resources import as_file, files
51+
52+
resource = files(_ONNX_PACKAGE).joinpath(filename)
53+
try:
54+
with as_file(resource) as path:
55+
if path.exists():
56+
return path
57+
except (TypeError, FileNotFoundError):
58+
return None
59+
else:
60+
# Python 3.8 fallback
61+
try:
62+
from importlib_resources import as_file, files
63+
64+
resource = files(_ONNX_PACKAGE).joinpath(filename)
65+
with as_file(resource) as path:
66+
if path.exists():
67+
return path
68+
except ImportError:
69+
import importlib.resources as pkg_resources
70+
71+
try:
72+
with pkg_resources.path(_ONNX_PACKAGE, filename) as path:
73+
if path.exists():
74+
return path
75+
except (ModuleNotFoundError, FileNotFoundError, TypeError):
76+
return None
77+
except ModuleNotFoundError:
78+
return None
79+
return None
80+
81+
82+
def _is_onnx_package_installed() -> bool:
83+
"""Check if the surfaces-onnx-files package is installed."""
84+
try:
85+
if sys.version_info >= (3, 9):
86+
from importlib.resources import files
87+
88+
files(_ONNX_PACKAGE)
89+
else:
90+
try:
91+
from importlib_resources import files
92+
93+
files(_ONNX_PACKAGE)
94+
except ImportError:
95+
import importlib
96+
97+
importlib.import_module(_ONNX_PACKAGE)
98+
return True
99+
except ModuleNotFoundError:
100+
return False
101+
102+
103+
def get_onnx_file(filename: str) -> Optional[Path]:
104+
"""Get path to an ONNX model file.
105+
106+
Checks in order:
107+
1. Local models directory (for newly trained models)
108+
2. Local data package directory (for development)
109+
3. Installed surfaces-onnx-files package
110+
111+
Parameters
112+
----------
113+
filename : str
114+
Name of the ONNX file (e.g., "k_neighbors_regressor.onnx").
115+
116+
Returns
117+
-------
118+
Path or None
119+
Path to the file if found, None otherwise.
120+
"""
121+
# Check local models directory first (for newly trained models)
122+
trained_path = _get_trained_model_path(filename)
123+
if trained_path is not None:
124+
return trained_path
125+
126+
# Check local data package directory (for development)
127+
local_path = _get_local_onnx_path(filename)
128+
if local_path is not None:
129+
return local_path
130+
131+
# Check installed package
132+
installed_path = _get_installed_onnx_path(filename)
133+
if installed_path is not None:
134+
return installed_path
135+
136+
return None
137+
138+
139+
def get_surrogate_model_path(function_name: str) -> Optional[Path]:
140+
"""Get path to a pre-trained surrogate model.
141+
142+
Parameters
143+
----------
144+
function_name : str
145+
Name of the function (e.g., "k_neighbors_classifier").
146+
147+
Returns
148+
-------
149+
Path or None
150+
Path to the ONNX file if it exists, None otherwise.
151+
"""
152+
return get_onnx_file(f"{function_name}.onnx")
153+
154+
155+
def get_metadata_path(function_name: str) -> Optional[Path]:
156+
"""Get path to a surrogate model's metadata file.
157+
158+
Parameters
159+
----------
160+
function_name : str
161+
Name of the function (e.g., "k_neighbors_classifier").
162+
163+
Returns
164+
-------
165+
Path or None
166+
Path to the .meta.json file if it exists, None otherwise.
167+
"""
168+
return get_onnx_file(f"{function_name}.onnx.meta.json")
169+
170+
171+
def get_validity_model_path(function_name: str) -> Optional[Path]:
172+
"""Get path to a surrogate model's validity model.
173+
174+
Parameters
175+
----------
176+
function_name : str
177+
Name of the function (e.g., "k_neighbors_classifier").
178+
179+
Returns
180+
-------
181+
Path or None
182+
Path to the .validity.onnx file if it exists, None otherwise.
183+
"""
184+
return get_onnx_file(f"{function_name}.validity.onnx")

0 commit comments

Comments
 (0)