
Commit 161785e

Luca Candela authored and committed
add minimal ci and strengthen episode ingestion coverage
1 parent 0c0cf36 · commit 161785e

9 files changed (+497, −10 lines)

.github/workflows/ci.yml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: CI
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  test:
+    runs-on: ubuntu-22.04
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.12'
+
+      - name: Install uv
+        shell: bash
+        run: curl -LsSf https://astral.sh/uv/install.sh | sh
+
+      - name: Cache uv packages
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/uv
+          key: ${{ runner.os }}-uv-${{ hashFiles('uv.lock') }}
+          restore-keys: |
+            ${{ runner.os }}-uv-
+
+      - name: Install dependencies
+        run: uv sync --extra dev
+
+      - name: Run targeted tests
+        run: |
+          uv run pytest tests/orchestration/test_bulk_serialization.py tests/search/test_search_utils_filters.py
+          uv run pytest tests/test_graphium_mock.py::test_add_episode_persists_nodes_and_edges

mcp_server/graphium_mcp/queues.py

Lines changed: 71 additions & 5 deletions
@@ -1,14 +1,58 @@
-"""Episode queue orchestration."""
+"""Episode queue orchestration with retry tracking."""
 
 from __future__ import annotations
 
 import asyncio
 import logging
+from datetime import UTC, datetime
+from typing import Any
 
 from . import state
 
 logger = logging.getLogger(__name__)
 
+MAX_QUEUE_RETRIES = 3
+RETRY_BACKOFF_SECONDS = 0.5
+
+
+def _metadata(process_func: state.EpisodeProcessor) -> dict[str, Any]:
+    metadata = getattr(process_func, 'queue_metadata', {})
+    if not isinstance(metadata, dict):
+        metadata = {}
+    metadata.setdefault('attempts', 0)
+    return metadata
+
+
+def _record_failure(group_id: str, metadata: dict[str, Any], exc: Exception) -> None:
+    name = metadata.get('name', 'unknown')
+    attempts = metadata.get('attempts', 0)
+    failure = {
+        'name': str(name),
+        'error': f'{exc.__class__.__name__}: {exc}',
+        'attempts': str(attempts),
+        'timestamp': datetime.now(UTC).isoformat(),
+    }
+
+    failures = state.queue_failures.setdefault(group_id, [])
+    if name:
+        failures = [entry for entry in failures if entry.get('name') != name]
+    failures.append(failure)
+    state.queue_failures[group_id] = failures
+
+
+def _clear_failure(group_id: str, metadata: dict[str, Any]) -> None:
+    name = metadata.get('name')
+    if not name:
+        return
+    failures = state.queue_failures.get(group_id)
+    if not failures:
+        return
+    remaining = [entry for entry in failures if entry.get('name') != name]
+    if remaining:
+        state.queue_failures[group_id] = remaining
+    else:
+        state.queue_failures.pop(group_id, None)
 
 
 async def process_episode_queue(group_id: str) -> None:
     """Process episodes for a specific group_id sequentially."""
@@ -18,14 +62,33 @@ async def process_episode_queue(group_id: str) -> None:
     try:
         while True:
             process_func = await state.episode_queues[group_id].get()
+            metadata = _metadata(process_func)
+            name = metadata.get('name', 'queued-episode')
             try:
                 await process_func()
+                metadata['attempts'] = 0
+                _clear_failure(group_id, metadata)
             except Exception as exc:  # pragma: no cover - defensive logging
-                logger.error(
-                    'Error processing queued episode for group_id %s: %s',
+                metadata['attempts'] = metadata.get('attempts', 0) + 1
+                _record_failure(group_id, metadata, exc)
+                attempt = metadata['attempts']
+                logger.exception(
+                    "Error processing queued episode '%s' for group_id %s (attempt %s/%s)",
+                    name,
                     group_id,
-                    exc,
+                    attempt,
+                    MAX_QUEUE_RETRIES,
                 )
+                if attempt < MAX_QUEUE_RETRIES:
+                    setattr(process_func, 'queue_metadata', metadata)
+                    await asyncio.sleep(RETRY_BACKOFF_SECONDS)
+                    await state.episode_queues[group_id].put(process_func)
+                else:
+                    logger.error(
+                        "Episode '%s' for group_id %s exceeded max retries and will be discarded",
+                        name,
+                        group_id,
+                    )
             finally:
                 state.episode_queues[group_id].task_done()
     except asyncio.CancelledError:
@@ -46,6 +109,9 @@ async def enqueue_episode(group_id: str, process_func: state.EpisodeProcessor) -
     if group_id not in state.episode_queues:
         state.episode_queues[group_id] = asyncio.Queue()
 
+    metadata = _metadata(process_func)
+    setattr(process_func, 'queue_metadata', metadata)
+
     await state.episode_queues[group_id].put(process_func)
 
     if not state.queue_workers.get(group_id, False):
@@ -54,4 +120,4 @@ async def enqueue_episode(group_id: str, process_func: state.EpisodeProcessor) -
     return state.episode_queues[group_id].qsize()
 
 
-__all__ = ['enqueue_episode', 'process_episode_queue']
+__all__ = ['MAX_QUEUE_RETRIES', 'enqueue_episode', 'process_episode_queue']

mcp_server/graphium_mcp/state.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 episode_queues: dict[str, asyncio.Queue[EpisodeProcessor]] = {}
 queue_workers: dict[str, bool] = {}
 graphium_init_error: str | None = None
+queue_failures: dict[str, list[dict[str, str]]] = {}
 
 
 def set_config(new_config: GraphiumConfig) -> None:
@@ -42,6 +43,7 @@ def set_init_error(error: str | None) -> None:
     'graphium_config',
     'graphium_init_error',
     'queue_workers',
+    'queue_failures',
     'set_init_error',
     'set_client',
     'set_config',

mcp_server/graphium_mcp/status.py

Lines changed: 10 additions & 2 deletions
@@ -33,9 +33,17 @@ async def collect_status() -> StatusResponse:
         graphium = cast(Graphium, client)
         await graphium.driver.client.verify_connectivity()  # type: ignore
 
+        status_value = 'ok'
+        message = 'Graphium MCP server is running and connected to Neo4j'
+
+        if state.queue_failures:
+            failure_count = sum(len(entries) for entries in state.queue_failures.values())
+            status_value = 'warn'
+            message = f'{message} (pending queue failures: {failure_count})'
+
         return StatusResponse(
-            status='ok',
-            message='Graphium MCP server is running and connected to Neo4j',
+            status=status_value,
+            message=message,
        )
     except Exception as exc:
         logger.error('Error checking Neo4j connection: %s', exc)
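
The warn path just counts every recorded failure across all groups and folds the total into the existing message. A quick illustration with made-up failure entries (a local dict, not the real state module):

queue_failures = {
    'group-a': [{'name': 'ep-1', 'error': 'RuntimeError: boom'}],
    'group-b': [
        {'name': 'ep-2', 'error': 'TimeoutError: slow'},
        {'name': 'ep-3', 'error': 'ValueError: bad payload'},
    ],
}

failure_count = sum(len(entries) for entries in queue_failures.values())
message = 'Graphium MCP server is running and connected to Neo4j'
if queue_failures:
    message = f'{message} (pending queue failures: {failure_count})'

print(message)
# Graphium MCP server is running and connected to Neo4j (pending queue failures: 3)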

mcp_server/graphium_mcp/tools.py

Lines changed: 17 additions & 3 deletions
@@ -106,11 +106,25 @@ async def process_episode() -> None:
                 group_id_str,
             )
 
+        process_episode.queue_metadata = {
+            'name': name,
+            'group_id': group_id_str,
+            'source': source_type.value,
+        }
+
         position = await enqueue_episode(group_id_str, process_episode)
 
-        return SuccessResponse(
-            message=f"Episode '{name}' queued for processing (position: {position})"
-        )
+        pending_failures = state.queue_failures.get(group_id_str, [])
+        message = f"Episode '{name}' queued for processing (position: {position})"
+        if pending_failures:
+            last_failure = pending_failures[-1]
+            message = (
+                f"{message}. Warning: {len(pending_failures)} prior failure(s); "
+                f"last error from '{last_failure.get('name', 'unknown')}': "
+                f"{last_failure.get('error', 'unknown error')}"
+            )
+
+        return SuccessResponse(message=message)
     except Exception as exc:
         logger.exception(
             "Error queuing episode '%s' for group_id %s",

tests/mcp/test_episode_queue.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import asyncio
+import logging
+from contextlib import suppress
+
+import pytest
+
+from mcp_server.graphium_mcp import queues, state
+
+
+@pytest.mark.asyncio
+async def test_enqueue_episode_retries_and_records_failures(monkeypatch, caplog):
+    group_id = 'queue-test-group'
+    attempts = 0
+    retry_complete = asyncio.Event()
+
+    async def failing_processor():
+        nonlocal attempts
+        attempts += 1
+        if attempts >= queues.MAX_QUEUE_RETRIES:
+            retry_complete.set()
+        raise RuntimeError('boom')
+
+    failing_processor.queue_metadata = {'name': 'test-episode'}
+
+    original_create_task = asyncio.create_task
+    created_tasks: list[asyncio.Task] = []
+
+    def capture_task(coro):
+        task = original_create_task(coro)
+        created_tasks.append(task)
+        return task
+
+    monkeypatch.setattr(asyncio, 'create_task', capture_task)
+    caplog.set_level(logging.WARNING, logger='mcp_server.graphium_mcp.queues')
+    state.queue_failures.pop(group_id, None)
+
+    try:
+        position = await queues.enqueue_episode(group_id, failing_processor)
+        assert position == 1
+
+        await asyncio.wait_for(retry_complete.wait(), timeout=3)
+        await asyncio.wait_for(state.episode_queues[group_id].join(), timeout=3)
+
+        assert attempts == queues.MAX_QUEUE_RETRIES
+        failures = state.queue_failures[group_id]
+        assert failures[-1]['name'] == 'test-episode'
+        assert failures[-1]['attempts'] == str(queues.MAX_QUEUE_RETRIES)
+        assert any('exceeded max retries' in record.message for record in caplog.records)
+    finally:
+        for task in created_tasks:
+            task.cancel()
+            with suppress(asyncio.CancelledError):
+                await task
+        state.episode_queues.pop(group_id, None)
+        state.queue_workers.pop(group_id, None)
+        state.queue_failures.pop(group_id, None)
+        monkeypatch.setattr(asyncio, 'create_task', original_create_task)

tests/orchestration/test_bulk_serialization.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+from datetime import UTC, datetime
+
+from graphium_core.nodes import EpisodicNode, EpisodeType
+from graphium_core.orchestration import bulk
+
+
+def test_serialize_episodes_preserves_entity_edges():
+    now = datetime.now(UTC)
+    episode = EpisodicNode(
+        name='episode-1',
+        group_id='group-123',
+        labels=['Conversation'],
+        created_at=now,
+        source=EpisodeType.message,
+        source_description='chat message',
+        content='user: hello world',
+        valid_at=now,
+        entity_edges=['edge-1', 'edge-2'],
+    )
+
+    payloads = bulk._serialize_episodes([episode])
+
+    assert len(payloads) == 1
+    payload = payloads[0]
+    assert payload.uuid == episode.uuid
+    assert payload.source == EpisodeType.message.value
+    assert payload.entity_edges == ['edge-1', 'edge-2']
