Skip to content

Commit cc4361a

Browse files
authored
Allow user provided crYOLO model (#311)
Search a configured directory for a crYOLO model before falling back on a configured default
1 parent 66dfb0e commit cc4361a

File tree

4 files changed

+59
-1
lines changed

4 files changed

+59
-1
lines changed

src/murfey/server/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from murfey.server.config import from_file, settings
5050
from murfey.server.gain import Camera, prepare_eer_gain, prepare_gain
5151
from murfey.server.murfey_db import murfey_db
52+
from murfey.server.spa.api import _cryolo_model_path
5253
from murfey.util.db import (
5354
AutoProcProgram,
5455
ClientEnvironment,
@@ -997,8 +998,11 @@ async def request_spa_preprocessing(
997998
Path(secure_filename(str(mrc_out))).parent.mkdir(
998999
parents=True, exist_ok=True
9991000
)
1001+
recipe_name = machine_config.recipes.get(
1002+
"em-spa-preprocess", "em-spa-preprocess"
1003+
)
10001004
zocalo_message = {
1001-
"recipes": ["em-spa-preprocess"],
1005+
"recipes": [recipe_name],
10021006
"parameters": {
10031007
"feedback_queue": machine_config.feedback_queue,
10041008
"node_creator_queue": machine_config.node_creator_queue,
@@ -1019,6 +1023,7 @@ async def request_spa_preprocessing(
10191023
"particle_diameter": proc_params["particle_diameter"] or 0,
10201024
"fm_int_file": proc_file.eer_fractionation_file,
10211025
"do_icebreaker_jobs": default_spa_parameters.do_icebreaker_jobs,
1026+
"cryolo_model_weights": str(_cryolo_model_path(visit_name)),
10221027
},
10231028
}
10241029
# log.info(f"Sending Zocalo message {zocalo_message}")

src/murfey/server/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class MachineConfig(BaseModel):
1717
rsync_basepath: Path
1818
murfey_db_credentials: str
1919
crypto_key: str
20+
default_model: Path
2021
display_name: str = ""
2122
image_path: Optional[Path] = None
2223
software_versions: Dict[str, str] = {}
@@ -41,11 +42,23 @@ class MachineConfig(BaseModel):
4142
processed_extra_directory: str = ""
4243
plugin_packages: Dict[str, Path] = {}
4344
software_settings_output_directories: Dict[str, List[str]] = {}
45+
recipes: Dict[str, str] = {
46+
"em-spa-bfactor": "em-spa-bfactor",
47+
"em-spa-class2d": "em-spa-class2d",
48+
"em-spa-class3d": "em-spa-class3d",
49+
"em-spa-preprocess": "em-spa-preprocess",
50+
"em-spa-refine": "em-spa-refine",
51+
"em-tomo-preprocess": "em-tomo-preprocess",
52+
"em-tomo-align": "em-tomo-align",
53+
}
4454

4555
# Find and download upstream directories
4656
upstream_data_directories: List[Path] = [] # Previous sessions
4757
upstream_data_download_directory: Optional[Path] = None # Set by microscope config
4858
upstream_data_tiff_locations: List[str] = ["processed"] # Location of CLEM TIFFs
59+
60+
model_search_directory: str = "processing"
61+
4962
failure_queue: str = ""
5063
auth_key: str = ""
5164
auth_algorithm: str = ""
@@ -93,6 +106,7 @@ def get_machine_config() -> MachineConfig:
93106
rsync_basepath=Path("dls/tmp"),
94107
murfey_db_credentials="",
95108
crypto_key="",
109+
default_model="/tmp/weights.h5",
96110
)
97111
if settings.murfey_machine_configuration:
98112
microscope = get_microscope()

src/murfey/server/spa/__init__.py

Whitespace-only changes.

src/murfey/server/spa/api.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime
4+
from functools import lru_cache
5+
from pathlib import Path
6+
7+
from fastapi import APIRouter
8+
from sqlmodel import select
9+
10+
from murfey.server import get_machine_config
11+
from murfey.server.murfey_db import murfey_db
12+
from murfey.util.db import Session
13+
14+
# Create APIRouter class object
15+
router = APIRouter()
16+
17+
18+
@lru_cache(maxsize=5)
19+
def _cryolo_model_path(visit: str) -> Path:
20+
machine_config = get_machine_config()
21+
if machine_config.model_search_directory:
22+
visit_directory = (
23+
machine_config.rsync_basepath
24+
/ (machine_config.rsync_module or "data")
25+
/ str(datetime.now().year)
26+
/ visit
27+
)
28+
possible_models = list(
29+
(visit_directory / machine_config.model_search_directory).glob("*.h5")
30+
)
31+
if possible_models:
32+
return sorted(possible_models, key=lambda x: x.stat().st_ctime)[-1]
33+
return machine_config.default_model
34+
35+
36+
@router.get("/sessions/{session_id}/cryolo_model")
37+
def get_cryolo_model_path(session_id: int, db=murfey_db):
38+
visit = db.exec(select(Session).where(Session.session_id == session_id)).one().visit
39+
return {"model_path": _cryolo_model_path(visit)}

0 commit comments

Comments
 (0)