Skip to content

Commit 96b4ca6

Browse files
committed
adapt dashboard and loader to new onnx package setup
1 parent 15d8429 commit 96b4ca6

File tree

2 files changed

+14
-32
lines changed

2 files changed

+14
-32
lines changed

src/surfaces/_surrogates/_dashboard/sync.py

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
from typing import Optional
1616

1717
from .database import init_db, upsert_surrogate
18-
19-
# Paths
20-
MODELS_DIR = Path(__file__).parent.parent / "models"
18+
from .._onnx_utils import get_metadata_path, get_surrogate_model_path
2119

2220

2321
def compute_file_hash(file_path: Path) -> str:
2422
"""Compute SHA256 hash of a file."""
25-
if not file_path.exists():
23+
if file_path is None or not file_path.exists():
2624
return ""
2725
sha256 = hashlib.sha256()
2826
with open(file_path, "rb") as f:
@@ -46,18 +44,16 @@ def get_function_type(function_name: str) -> str:
4644
return "unknown"
4745

4846

49-
def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None) -> dict:
47+
def sync_all(db_path: Optional[Path] = None) -> dict:
5048
"""Sync all ML functions from registry to database.
5149
5250
This function:
5351
1. Gets all registered ML functions from the registry
54-
2. Checks if each has a .onnx and .meta.json file
52+
2. Checks if each has a .onnx and .meta.json file (in data package or local)
5553
3. Upserts the data into SQLite
5654
5755
Parameters
5856
----------
59-
models_dir : Path, optional
60-
Directory containing .onnx and .meta.json files.
6157
db_path : Path, optional
6258
Path to SQLite database.
6359
@@ -66,8 +62,6 @@ def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None)
6662
dict
6763
Sync statistics.
6864
"""
69-
models_dir = models_dir or MODELS_DIR
70-
7165
# Initialize database
7266
init_db(db_path)
7367

@@ -83,10 +77,10 @@ def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None)
8377
}
8478

8579
for function_name in functions:
86-
onnx_path = models_dir / f"{function_name}.onnx"
87-
meta_path = models_dir / f"{function_name}.onnx.meta.json"
80+
onnx_path = get_surrogate_model_path(function_name)
81+
meta_path = get_metadata_path(function_name)
8882

89-
has_surrogate = onnx_path.exists() and meta_path.exists()
83+
has_surrogate = onnx_path is not None and meta_path is not None
9084

9185
if has_surrogate:
9286
stats["with_surrogate"] += 1
@@ -114,7 +108,6 @@ def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None)
114108

115109
def sync_single(
116110
function_name: str,
117-
models_dir: Optional[Path] = None,
118111
db_path: Optional[Path] = None,
119112
) -> bool:
120113
"""Sync a single function to the database.
@@ -123,8 +116,6 @@ def sync_single(
123116
----------
124117
function_name : str
125118
Name of the function to sync.
126-
models_dir : Path, optional
127-
Directory containing model files.
128119
db_path : Path, optional
129120
Path to SQLite database.
130121
@@ -133,12 +124,10 @@ def sync_single(
133124
bool
134125
True if surrogate exists, False otherwise.
135126
"""
136-
models_dir = models_dir or MODELS_DIR
137-
138-
onnx_path = models_dir / f"{function_name}.onnx"
139-
meta_path = models_dir / f"{function_name}.onnx.meta.json"
127+
onnx_path = get_surrogate_model_path(function_name)
128+
meta_path = get_metadata_path(function_name)
140129

141-
has_surrogate = onnx_path.exists() and meta_path.exists()
130+
has_surrogate = onnx_path is not None and meta_path is not None
142131

143132
if has_surrogate:
144133
metadata = load_meta_json(meta_path)
@@ -163,7 +152,6 @@ def sync_single(
163152

164153
def check_sync_needed(
165154
function_name: str,
166-
models_dir: Optional[Path] = None,
167155
db_path: Optional[Path] = None,
168156
) -> bool:
169157
"""Check if a function needs to be re-synced.
@@ -174,8 +162,6 @@ def check_sync_needed(
174162
----------
175163
function_name : str
176164
Name of the function to check.
177-
models_dir : Path, optional
178-
Directory containing model files.
179165
db_path : Path, optional
180166
Path to SQLite database.
181167
@@ -186,11 +172,10 @@ def check_sync_needed(
186172
"""
187173
from .database import get_surrogate
188174

189-
models_dir = models_dir or MODELS_DIR
190-
onnx_path = models_dir / f"{function_name}.onnx"
175+
onnx_path = get_surrogate_model_path(function_name)
191176

192177
# Get current hash
193-
current_hash = compute_file_hash(onnx_path) if onnx_path.exists() else None
178+
current_hash = compute_file_hash(onnx_path) if onnx_path is not None else None
194179

195180
# Get stored hash
196181
surrogate = get_surrogate(function_name, db_path)

src/surfaces/_surrogates/_surrogate_loader.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,9 @@ def get_surrogate_path(function_name: str) -> Optional[Path]:
212212
Path or None
213213
Path to the ONNX file if it exists, None otherwise.
214214
"""
215-
models_dir = Path(__file__).parent / "models"
216-
model_path = models_dir / f"{function_name}.onnx"
215+
from ._onnx_utils import get_surrogate_model_path
217216

218-
if model_path.exists():
219-
return model_path
220-
return None
217+
return get_surrogate_model_path(function_name)
221218

222219

223220
def load_surrogate(function_name: str) -> Optional[SurrogateLoader]:

0 commit comments

Comments
 (0)