1515from typing import Optional
1616
1717from .database import init_db , upsert_surrogate
18-
19- # Paths
20- MODELS_DIR = Path (__file__ ).parent .parent / "models"
18+ from .._onnx_utils import get_metadata_path , get_surrogate_model_path
2119
2220
2321def compute_file_hash (file_path : Path ) -> str :
2422 """Compute SHA256 hash of a file."""
25- if not file_path .exists ():
23+ if file_path is None or not file_path .exists ():
2624 return ""
2725 sha256 = hashlib .sha256 ()
2826 with open (file_path , "rb" ) as f :
@@ -46,18 +44,16 @@ def get_function_type(function_name: str) -> str:
4644 return "unknown"
4745
4846
49- def sync_all (models_dir : Optional [ Path ] = None , db_path : Optional [Path ] = None ) -> dict :
47+ def sync_all (db_path : Optional [Path ] = None ) -> dict :
5048 """Sync all ML functions from registry to database.
5149
5250 This function:
5351 1. Gets all registered ML functions from the registry
54- 2. Checks if each has a .onnx and .meta.json file
52+ 2. Checks if each has a .onnx and .meta.json file (in data package or local)
5553 3. Upserts the data into SQLite
5654
5755 Parameters
5856 ----------
59- models_dir : Path, optional
60- Directory containing .onnx and .meta.json files.
6157 db_path : Path, optional
6258 Path to SQLite database.
6359
@@ -66,8 +62,6 @@ def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None)
6662 dict
6763 Sync statistics.
6864 """
69- models_dir = models_dir or MODELS_DIR
70-
7165 # Initialize database
7266 init_db (db_path )
7367
@@ -83,10 +77,10 @@ def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None)
8377 }
8478
8579 for function_name in functions :
86- onnx_path = models_dir / f" { function_name } .onnx"
87- meta_path = models_dir / f" { function_name } .onnx.meta.json"
80+ onnx_path = get_surrogate_model_path ( function_name )
81+ meta_path = get_metadata_path ( function_name )
8882
89- has_surrogate = onnx_path . exists () and meta_path . exists ()
83+ has_surrogate = onnx_path is not None and meta_path is not None
9084
9185 if has_surrogate :
9286 stats ["with_surrogate" ] += 1
@@ -114,7 +108,6 @@ def sync_all(models_dir: Optional[Path] = None, db_path: Optional[Path] = None)
114108
115109def sync_single (
116110 function_name : str ,
117- models_dir : Optional [Path ] = None ,
118111 db_path : Optional [Path ] = None ,
119112) -> bool :
120113 """Sync a single function to the database.
@@ -123,8 +116,6 @@ def sync_single(
123116 ----------
124117 function_name : str
125118 Name of the function to sync.
126- models_dir : Path, optional
127- Directory containing model files.
128119 db_path : Path, optional
129120 Path to SQLite database.
130121
@@ -133,12 +124,10 @@ def sync_single(
133124 bool
134125 True if surrogate exists, False otherwise.
135126 """
136- models_dir = models_dir or MODELS_DIR
137-
138- onnx_path = models_dir / f"{ function_name } .onnx"
139- meta_path = models_dir / f"{ function_name } .onnx.meta.json"
127+ onnx_path = get_surrogate_model_path (function_name )
128+ meta_path = get_metadata_path (function_name )
140129
141- has_surrogate = onnx_path . exists () and meta_path . exists ()
130+ has_surrogate = onnx_path is not None and meta_path is not None
142131
143132 if has_surrogate :
144133 metadata = load_meta_json (meta_path )
@@ -163,7 +152,6 @@ def sync_single(
163152
164153def check_sync_needed (
165154 function_name : str ,
166- models_dir : Optional [Path ] = None ,
167155 db_path : Optional [Path ] = None ,
168156) -> bool :
169157 """Check if a function needs to be re-synced.
@@ -174,8 +162,6 @@ def check_sync_needed(
174162 ----------
175163 function_name : str
176164 Name of the function to check.
177- models_dir : Path, optional
178- Directory containing model files.
179165 db_path : Path, optional
180166 Path to SQLite database.
181167
@@ -186,11 +172,10 @@ def check_sync_needed(
186172 """
187173 from .database import get_surrogate
188174
189- models_dir = models_dir or MODELS_DIR
190- onnx_path = models_dir / f"{ function_name } .onnx"
175+ onnx_path = get_surrogate_model_path (function_name )
191176
192177 # Get current hash
193- current_hash = compute_file_hash (onnx_path ) if onnx_path . exists () else None
178+ current_hash = compute_file_hash (onnx_path ) if onnx_path is not None else None
194179
195180 # Get stored hash
196181 surrogate = get_surrogate (function_name , db_path )
0 commit comments