diff --git a/src/athena_api/client.py b/src/athena_api/client.py index 00eb1637..dbca4653 100644 --- a/src/athena_api/client.py +++ b/src/athena_api/client.py @@ -81,7 +81,7 @@ def _request( self, method: str, endpoint: str, - request_model: BaseModel | None = None, + request_model: BaseModel | list[BaseModel] | None = None, response_cls=None, ): """ @@ -90,7 +90,7 @@ def _request( Args: method: HTTP method (get, post, put, delete) endpoint: API endpoint path - request_model: Optional request data model + request_model: Optional request data model or list of models response_cls: Optional response class to parse the response Returns: @@ -106,7 +106,10 @@ def _request( json_data = None if request_model: - if hasattr(request_model, "model_dump"): + if isinstance(request_model, list): + # It's a list of Pydantic models + json_data = [m.model_dump(mode="json", exclude_none=True) for m in request_model] + elif hasattr(request_model, "model_dump"): # It's a Pydantic model json_data = request_model.model_dump(mode="json", exclude_none=True) else: @@ -313,6 +316,10 @@ def record_decision(self, decision: DecisionRecord) -> DecisionRecord: """Record a decision.""" return self._request("post", "api/v1/Decision", decision, DecisionRecord) + def record_decisions(self, decisions: list[DecisionRecord]) -> list[DecisionRecord]: + """Record multiple decisions.""" + return self._request("post", "api/v1/Decisions", decisions, DecisionRecord) + # Session endpoints def register_session(self, session: Session) -> Session: """Register a session.""" diff --git a/src/smartem_backend/consumer.py b/src/smartem_backend/consumer.py index e777cec2..1a068157 100755 --- a/src/smartem_backend/consumer.py +++ b/src/smartem_backend/consumer.py @@ -737,7 +737,7 @@ def handle_gridsquare_model_prediction(event_data: dict[str, Any]) -> None: metric_name=event.metric, ) else: - current_quality_prediction = event.prediction_value + current_quality_prediction.value = event.prediction_value session.add(current_quality_prediction) session.commit() @@ -785,7 +785,7 @@ def handle_foilhole_model_prediction(event_data: dict[str, Any]) -> None: metric_name=event.metric, ) else: - current_quality_prediction = event.prediction_value + current_quality_prediction.value = event.prediction_value session.add(current_quality_prediction) session.commit() @@ -814,7 +814,7 @@ def handle_multi_foilhole_model_prediction(event_data: dict[str, Any]) -> None: for fhuuid in event.foilhole_uuids ] with Session(db_engine) as session: - session.add(quality_predictions) + session.add_all(quality_predictions) current_quality_predictions = session.exec( select(CurrentQualityPrediction) .where(CurrentQualityPrediction.foilhole_uuid.in_(event.foilhole_uuids))