7 changes: 6 additions & 1 deletion src/epu_data_intake/fs_parser.py
@@ -1020,7 +1020,12 @@ def parse_grid_dir(grid_data_dir: str, datastore: InMemoryDataStore) -> str:
grid_uuid=grid.uuid, # Set reference to parent grid
)

if grid.atlas_data.gridsquare_positions.get(int(gridsquare_id)) is not None:
# Check if atlas data exists and has gridsquare positions before accessing
if (
grid.atlas_data is not None
and grid.atlas_data.gridsquare_positions is not None
and grid.atlas_data.gridsquare_positions.get(int(gridsquare_id)) is not None
):
found_grid_square = datastore.find_gridsquare_by_natural_id(gridsquare_id)
gridsquare.uuid = found_grid_square.uuid
datastore.update_gridsquare(gridsquare)
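For readers skimming the change: the new three-clause guard short-circuits left to right, so it never dereferences `gridsquare_positions` when `atlas_data` is missing. A minimal, hypothetical refactoring of the same check into a helper (not part of this PR; the helper name is illustrative) could look like this:

```python
def _has_gridsquare_position(grid, gridsquare_id: str) -> bool:
    """Hypothetical helper: True only when atlas data, its position map,
    and this particular gridsquare id are all present."""
    atlas = grid.atlas_data
    if atlas is None or atlas.gridsquare_positions is None:
        return False
    return atlas.gridsquare_positions.get(int(gridsquare_id)) is not None
```

Either form avoids the AttributeError the original one-line check raised whenever `atlas_data` or its position map was None.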
73 changes: 41 additions & 32 deletions src/epu_data_intake/fs_watcher.py
@@ -299,39 +299,48 @@ def _on_atlas_detected(self, path: str, grid_uuid: str, is_new_file: bool = True
grid.atlas_data = atlas_data
self.datastore.update_grid(grid)
logging.debug(f"Updated atlas_data for grid: {grid_uuid}")
self.datastore.create_atlas(grid.atlas_data)
gs_uuid_map = {}
for gsid, gsp in grid.atlas_data.gridsquare_positions.items():
gridsquare = GridSquareData(
gridsquare_id=str(gsid),
metadata=None,
grid_uuid=grid.uuid,
center_x=gsp.center[0],
center_y=gsp.center[1],
size_width=gsp.size[0],
size_height=gsp.size[1],
)
# need to check if each square exists already
if found_grid_square := self.datastore.find_gridsquare_by_natural_id(str(gsid)):
gridsquare.uuid = found_grid_square.uuid
self.datastore.update_gridsquare(gridsquare)
gs_uuid_map[str(gsid)] = gridsquare.uuid
else:
self.datastore.create_gridsquare(gridsquare)
gs_uuid_map[str(gsid)] = gridsquare.uuid
logging.debug(f"Registered all squares for grid: {grid_uuid}")
for atlastile in grid.atlas_data.tiles:
for gsid, gs_tile_pos in atlastile.gridsquare_positions.items():
for pos in gs_tile_pos:
self.datastore.link_atlastile_to_gridsquare(
AtlasTileGridSquarePositionData(
gridsquare_uuid=gs_uuid_map[gsid],
tile_uuid=atlastile.uuid,
position=pos.position,
size=pos.size,
)

# Only proceed if atlas data parsing was successful and has gridsquare positions
if atlas_data is not None:
self.datastore.create_atlas(grid.atlas_data)
gs_uuid_map = {}

# Check if gridsquare_positions exists before iterating
if atlas_data.gridsquare_positions is not None:
for gsid, gsp in atlas_data.gridsquare_positions.items():
gridsquare = GridSquareData(
gridsquare_id=str(gsid),
metadata=None,
grid_uuid=grid.uuid,
center_x=gsp.center[0],
center_y=gsp.center[1],
size_width=gsp.size[0],
size_height=gsp.size[1],
)
logging.debug(f"Linked squares to tiles for gird: {grid_uuid}")
# need to check if each square exists already
if found_grid_square := self.datastore.find_gridsquare_by_natural_id(str(gsid)):
gridsquare.uuid = found_grid_square.uuid
self.datastore.update_gridsquare(gridsquare)
gs_uuid_map[str(gsid)] = gridsquare.uuid
else:
self.datastore.create_gridsquare(gridsquare)
gs_uuid_map[str(gsid)] = gridsquare.uuid
logging.debug(f"Registered all squares for grid: {grid_uuid}")

# Process atlas tiles only if atlas data and tiles exist
if atlas_data.tiles:
for atlastile in atlas_data.tiles:
for gsid, gs_tile_pos in atlastile.gridsquare_positions.items():
for pos in gs_tile_pos:
self.datastore.link_atlastile_to_gridsquare(
AtlasTileGridSquarePositionData(
gridsquare_uuid=gs_uuid_map[gsid],
tile_uuid=atlastile.uuid,
position=pos.position,
size=pos.size,
)
)
logging.debug(f"Linked squares to tiles for grid: {grid_uuid}")

def _on_gridsquare_metadata_detected(self, path: str, grid_uuid: str, is_new_file: bool = True):
logging.info(f"Gridsquare metadata {'detected' if is_new_file else 'updated'}: {path}")
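The per-square branch above (find by natural id, then update or create) is repeated for every `gsid`. A hypothetical extraction of that step, assuming only the `datastore` methods already shown in this diff, would be:

```python
def _upsert_gridsquare(datastore, gridsquare: "GridSquareData") -> str:
    """Hypothetical helper: reuse the existing square's UUID if the datastore
    already knows this natural id, otherwise create a new record."""
    if existing := datastore.find_gridsquare_by_natural_id(gridsquare.gridsquare_id):
        gridsquare.uuid = existing.uuid
        datastore.update_gridsquare(gridsquare)
    else:
        datastore.create_gridsquare(gridsquare)
    return gridsquare.uuid
```

With such a helper the loop body would collapse to `gs_uuid_map[str(gsid)] = _upsert_gridsquare(self.datastore, gridsquare)`.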
2 changes: 1 addition & 1 deletion src/epu_data_intake/model/schemas.py
@@ -172,7 +172,7 @@ class AtlasData(BaseModel):
id: str
acquisition_date: datetime
storage_folder: str
description: str
description: str | None = None
name: str
tiles: list[AtlasTileData]
gridsquare_positions: dict[int, GridSquarePosition] | None
17 changes: 17 additions & 0 deletions src/smartem_decisions/appconfig.yml
@@ -0,0 +1,17 @@
app:
name: smartem-decisions
# There is technically no guarantee that every micrograph will be covered by at most 2 batches.
# The size of the batches makes this incredibly unlikely, but there is nothing to guarantee it.
particle_select_batch_size: 50000

rabbitmq:
queue_name: smartem_decisions
routing_key: smartem_decisions

database:
# SQLAlchemy connection pool settings
pool_size: 10 # Number of connections to maintain in pool
max_overflow: 20 # Additional connections beyond pool_size
pool_timeout: 30 # Seconds to wait for connection from pool
pool_recycle: 3600 # Seconds after which connection is recreated
pool_pre_ping: true # Validate connections before use (recommended for production)
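The pool settings above correspond one-to-one to `sqlalchemy.create_engine()` keyword arguments. A minimal sketch of how a loader might apply them (the DSN and file path are placeholders; the real `get_db_engine()` may wire this differently):

```python
import yaml
from sqlalchemy import create_engine

with open("src/smartem_decisions/appconfig.yml") as fh:
    conf = yaml.safe_load(fh)

db = conf["database"]
engine = create_engine(
    "postgresql+psycopg2://user:pass@localhost/smartem",  # placeholder DSN
    pool_size=db["pool_size"],          # connections kept open in the pool
    max_overflow=db["max_overflow"],    # extra connections allowed under load
    pool_timeout=db["pool_timeout"],    # seconds to wait for a free connection
    pool_recycle=db["pool_recycle"],    # recycle connections older than this
    pool_pre_ping=db["pool_pre_ping"],  # validate connections before use
)
```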
20 changes: 14 additions & 6 deletions src/smartem_decisions/cli/initialise_prediction_model_weights.py
@@ -1,19 +1,22 @@
import typer
from sqlalchemy.engine import Engine
from sqlmodel import Session, select

from smartem_decisions.model.database import Grid, QualityPredictionModel, QualityPredictionModelWeight
from smartem_decisions.utils import logger, setup_postgres_connection
from smartem_decisions.utils import get_db_engine, logger


def initialise_all_models_for_grid(grid_uuid: str) -> None:
def initialise_all_models_for_grid(grid_uuid: str, engine: Engine = None) -> None:
"""
Initialise prediction model weights for all available models for a specific grid.

Args:
grid_uuid: UUID of the grid to initialise weights for
default_weight: Default weight value to assign (default: DEFAULT_PREDICTION_MODEL_WEIGHT)
engine: Optional database engine (uses singleton if not provided)
"""
engine = setup_postgres_connection()
if engine is None:
engine = get_db_engine()

with Session(engine) as sess:
# Get all available prediction models
models = sess.exec(select(QualityPredictionModel)).all()
@@ -45,16 +48,21 @@ def initialise_all_models_for_grid(grid_uuid: str) -> None:
sess.commit()


def initialise_prediction_model_for_grid(name: str, weight: float, grid_uuid: str | None = None) -> None:
def initialise_prediction_model_for_grid(
name: str, weight: float, grid_uuid: str | None = None, engine: Engine = None
) -> None:
"""
Initialise a single prediction model weight for a grid (CLI interface).

Args:
name: Prediction model name
weight: Weight value to assign
grid_uuid: Grid UUID (if None, uses first available grid)
engine: Optional database engine (uses singleton if not provided)
"""
engine = setup_postgres_connection()
if engine is None:
engine = get_db_engine()

with Session(engine) as sess:
if grid_uuid is None:
grid = sess.exec(select(Grid)).first()
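`get_db_engine()` itself is not shown in this diff; a minimal sketch of the kind of process-wide singleton it presumably provides (an assumption about `smartem_decisions.utils`, with a placeholder DSN) is:

```python
from functools import lru_cache

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


@lru_cache(maxsize=1)
def get_db_engine() -> Engine:
    # Assumed shape only: the real implementation likely builds the DSN and
    # pool settings from appconfig.yml rather than hard-coding them.
    return create_engine("postgresql+psycopg2://user:pass@localhost/smartem")
```

Because callers only fall back to the singleton when `engine is None`, tests can inject a throwaway engine (for example an in-memory SQLite database) without touching global state.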
27 changes: 21 additions & 6 deletions src/smartem_decisions/cli/random_model_predictions.py
@@ -2,10 +2,11 @@
from typing import Annotated

import typer
from sqlalchemy.engine import Engine
from sqlmodel import Session, select

from smartem_decisions.model.database import FoilHole, Grid, GridSquare, QualityPrediction, QualityPredictionModel
from smartem_decisions.utils import logger, setup_postgres_connection
from smartem_decisions.utils import get_db_engine, logger

DEFAULT_PREDICTION_RANGE = (0.0, 1.0)

@@ -17,10 +18,14 @@ def generate_random_predictions(
level: Annotated[
str, typer.Option(help="Magnification level at which to generate predictions. Options are 'hole' or 'square'")
] = "hole",
engine: Engine = None,
) -> None:
if level not in ("hole", "square"):
raise ValueError(f"Level must be set to either 'hole' or 'square' not {level}")
engine = setup_postgres_connection()

if engine is None:
engine = get_db_engine()

with Session(engine) as sess:
if grid_uuid is None:
grid = sess.exec(select(Grid)).first()
@@ -57,15 +62,20 @@ def generate_random_predictions(
return None


def generate_predictions_for_gridsquare(gridsquare_uuid: str, grid_uuid: str | None = None) -> None:
def generate_predictions_for_gridsquare(
gridsquare_uuid: str, grid_uuid: str | None = None, engine: Engine = None
) -> None:
"""
Generate random predictions for a single gridsquare using all available models.

Args:
gridsquare_uuid: UUID of the gridsquare to generate predictions for
grid_uuid: UUID of the parent grid (optional, will be looked up if not provided)
engine: Optional database engine (uses singleton if not provided)
"""
engine = setup_postgres_connection()
if engine is None:
engine = get_db_engine()

with Session(engine) as sess:
# Get all available prediction models
models = sess.exec(select(QualityPredictionModel)).all()
@@ -113,15 +123,20 @@ def generate_predictions_for_gridsquare(gridsquare_uuid: str, grid_uuid: str | N
logger.info(f"Generated {len(predictions)} predictions for gridsquare {gridsquare_uuid}")


def generate_predictions_for_foilhole(foilhole_uuid: str, gridsquare_uuid: str | None = None) -> None:
def generate_predictions_for_foilhole(
foilhole_uuid: str, gridsquare_uuid: str | None = None, engine: Engine = None
) -> None:
"""
Generate random predictions for a single foilhole using all available models.

Args:
foilhole_uuid: UUID of the foilhole to generate predictions for
gridsquare_uuid: UUID of the parent gridsquare (optional, for validation if provided)
engine: Optional database engine (uses singleton if not provided)
"""
engine = setup_postgres_connection()
if engine is None:
engine = get_db_engine()

with Session(engine) as sess:
# Get all available prediction models
models = sess.exec(select(QualityPredictionModel)).all()
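As a usage sketch of the new optional `engine` parameter (illustrative only; the in-memory SQLite engine and the exact call arguments are assumptions, not something this PR sets up):

```python
from sqlmodel import SQLModel, create_engine

from smartem_decisions.cli.random_model_predictions import generate_random_predictions

# Illustrative only: inject a throwaway engine instead of the process-wide singleton.
test_engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(test_engine)

generate_random_predictions(grid_uuid=None, level="hole", engine=test_engine)
```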
24 changes: 17 additions & 7 deletions src/smartem_decisions/cli/random_prior_updates.py
@@ -3,11 +3,12 @@
import time

import typer
from sqlalchemy.engine import Engine
from sqlmodel import Session, select

from smartem_decisions.model.database import FoilHole, Grid, GridSquare, Micrograph
from smartem_decisions.predictions.update import prior_update
from smartem_decisions.utils import logger, setup_postgres_connection
from smartem_decisions.utils import get_db_engine, logger

# Default time ranges for processing steps (in seconds)
DEFAULT_MOTION_CORRECTION_DELAY = (1.0, 3.0)
@@ -20,8 +21,11 @@ def perform_random_updates(
grid_uuid: str | None = None,
random_range: tuple[float, float] = (0, 1),
origin: str = "motion_correction",
engine: Engine = None,
) -> None:
engine = setup_postgres_connection()
if engine is None:
engine = get_db_engine()

with Session(engine) as sess:
if grid_uuid is None:
grid = sess.exec(select(Grid)).first()
@@ -37,15 +41,19 @@ def perform_random_updates(
return None


def simulate_processing_pipeline(micrograph_uuid: str) -> None:
def simulate_processing_pipeline(micrograph_uuid: str, engine: Engine = None) -> None:
"""
Simulate the data processing pipeline for a micrograph with random delays.

Pipeline: motion correction → ctf → particle picking → particle selection

Args:
micrograph_uuid: UUID of the micrograph to process
engine: Optional database engine (uses singleton if not provided)
"""
if engine is None:
engine = get_db_engine()

processing_steps = [
("motion_correction", DEFAULT_MOTION_CORRECTION_DELAY),
("ctf", DEFAULT_CTF_DELAY),
@@ -61,9 +69,8 @@ def simulate_processing_pipeline(micrograph_uuid: str) -> None:
logger.debug(f"Simulating {step_name} for micrograph {micrograph_uuid}, delay: {delay:.2f}s")
time.sleep(delay)

# Perform random weight update for this step
# Perform random weight update for this step - reuse the same engine
try:
engine = setup_postgres_connection()
with Session(engine) as sess:
# Generate random quality result (True/False)
quality_result = random.choice([True, False])
@@ -76,16 +83,19 @@ def simulate_processing_pipeline(micrograph_uuid: str) -> None:
logger.info(f"Completed processing pipeline simulation for micrograph {micrograph_uuid}")


def simulate_processing_pipeline_async(micrograph_uuid: str) -> None:
def simulate_processing_pipeline_async(micrograph_uuid: str, engine: Engine = None) -> None:
"""
Start the processing pipeline simulation in a background thread.

Args:
micrograph_uuid: UUID of the micrograph to process
engine: Optional database engine (uses singleton if not provided)
"""
if engine is None:
engine = get_db_engine()

def run_simulation():
simulate_processing_pipeline(micrograph_uuid)
simulate_processing_pipeline(micrograph_uuid, engine)

# Start simulation in background thread so it doesn't block the consumer
thread = threading.Thread(target=run_simulation, daemon=True)
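For reference, the background wrapper above boils down to the following pattern (a sketch; the function name is illustrative, and it assumes the shared SQLAlchemy `Engine`, whose connection pool is safe to use from multiple threads):

```python
import threading

from sqlalchemy.engine import Engine

from smartem_decisions.cli.random_prior_updates import simulate_processing_pipeline


def run_pipeline_in_background(micrograph_uuid: str, engine: Engine) -> threading.Thread:
    # Daemon thread so the consumer process can exit without waiting on simulations.
    thread = threading.Thread(
        target=simulate_processing_pipeline,
        args=(micrograph_uuid, engine),
        daemon=True,
    )
    thread.start()
    return thread
```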
8 changes: 0 additions & 8 deletions src/smartem_decisions/config.yaml

This file was deleted.

12 changes: 8 additions & 4 deletions src/smartem_decisions/consumer.py
@@ -42,6 +42,7 @@
MicrographUpdatedEvent,
)
from smartem_decisions.utils import (
get_db_engine,
load_conf,
rmq_consumer,
)
Expand All @@ -53,6 +54,9 @@
log_manager = LogManager.get_instance("smartem_decisions")
logger = log_manager.configure(LogConfig(level=logging.ERROR, console=True))

# Get singleton database engine for reuse across all event handlers
db_engine = get_db_engine()


def handle_acquisition_created(event_data: dict[str, Any]) -> None:
"""
@@ -174,7 +178,7 @@ def handle_grid_created(event_data: dict[str, Any], channel, delivery_tag) -> bo

# Initialise prediction model weights for all available models
try:
initialise_all_models_for_grid(event.uuid)
initialise_all_models_for_grid(event.uuid, engine=db_engine)
logger.info(f"Successfully initialised prediction model weights for grid {event.uuid}")
except Exception as weight_init_error:
logger.error(f"Failed to initialise prediction model weights for grid {event.uuid}: {weight_init_error}")
@@ -245,7 +249,7 @@ def handle_gridsquare_created(event_data: dict[str, Any], channel, delivery_tag)

# Generate random predictions for all available models
try:
generate_predictions_for_gridsquare(event.uuid, event.grid_uuid)
generate_predictions_for_gridsquare(event.uuid, event.grid_uuid, engine=db_engine)
logger.info(f"Successfully generated predictions for gridsquare {event.uuid}")
except Exception as prediction_error:
logger.error(f"Failed to generate predictions for gridsquare {event.uuid}: {prediction_error}")
@@ -316,7 +320,7 @@ def handle_foilhole_created(event_data: dict[str, Any], channel, delivery_tag) -

# Generate random predictions for all available models
try:
generate_predictions_for_foilhole(event.uuid, event.gridsquare_uuid)
generate_predictions_for_foilhole(event.uuid, event.gridsquare_uuid, engine=db_engine)
logger.info(f"Successfully generated predictions for foilhole {event.uuid}")
except Exception as prediction_error:
logger.error(f"Failed to generate predictions for foilhole {event.uuid}: {prediction_error}")
@@ -387,7 +391,7 @@ def handle_micrograph_created(event_data: dict[str, Any], channel, delivery_tag)

# Start simulated processing pipeline in background
try:
simulate_processing_pipeline_async(event.uuid)
simulate_processing_pipeline_async(event.uuid, engine=db_engine)
logger.info(f"Started processing pipeline simulation for micrograph {event.uuid}")
except Exception as simulation_error:
logger.error(f"Failed to start processing simulation for micrograph {event.uuid}: {simulation_error}")