Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/athena_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _request(
self,
method: str,
endpoint: str,
request_model: BaseModel | None = None,
request_model: BaseModel | list[BaseModel] | None = None,
response_cls=None,
):
"""
Expand All @@ -90,7 +90,7 @@ def _request(
Args:
method: HTTP method (get, post, put, delete)
endpoint: API endpoint path
request_model: Optional request data model
request_model: Optional request data model or list of models
response_cls: Optional response class to parse the response

Returns:
Expand All @@ -106,7 +106,10 @@ def _request(
json_data = None

if request_model:
if hasattr(request_model, "model_dump"):
if isinstance(request_model, list):
# It's a list of Pydantic models
json_data = [m.model_dump(mode="json", exclude_none=True) for m in request_model]
elif hasattr(request_model, "model_dump"):
# It's a Pydantic model
json_data = request_model.model_dump(mode="json", exclude_none=True)
else:
Expand Down Expand Up @@ -313,6 +316,10 @@ def record_decision(self, decision: DecisionRecord) -> DecisionRecord:
"""Record a decision."""
return self._request("post", "api/v1/Decision", decision, DecisionRecord)

def record_decisions(self, decisions: list[DecisionRecord]) -> list[DecisionRecord]:
"""Record multiple decisions."""
return self._request("post", "api/v1/Decisions", decisions, DecisionRecord)

# Session endpoints
def register_session(self, session: Session) -> Session:
"""Register a session."""
Expand Down
6 changes: 3 additions & 3 deletions src/smartem_backend/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def handle_gridsquare_model_prediction(event_data: dict[str, Any]) -> None:
metric_name=event.metric,
)
else:
current_quality_prediction = event.prediction_value
current_quality_prediction.value = event.prediction_value
session.add(current_quality_prediction)
session.commit()

Expand Down Expand Up @@ -785,7 +785,7 @@ def handle_foilhole_model_prediction(event_data: dict[str, Any]) -> None:
metric_name=event.metric,
)
else:
current_quality_prediction = event.prediction_value
current_quality_prediction.value = event.prediction_value
session.add(current_quality_prediction)
session.commit()

Expand Down Expand Up @@ -814,7 +814,7 @@ def handle_multi_foilhole_model_prediction(event_data: dict[str, Any]) -> None:
for fhuuid in event.foilhole_uuids
]
with Session(db_engine) as session:
session.add(quality_predictions)
session.add_all(quality_predictions)
current_quality_predictions = session.exec(
select(CurrentQualityPrediction)
.where(CurrentQualityPrediction.foilhole_uuid.in_(event.foilhole_uuids))
Expand Down