Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@ wheels/
.mypy_cache/
tmp/
mcp.json
.mcp.json
server-info.json
.DS_Store
.claude/
uv.lock
28 changes: 28 additions & 0 deletions kumo_rfm_mcp/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import time
from typing import Any

from kumoai.utils.progress_logger import PlainProgressLogger
from typing_extensions import Self


class McpProgressLogger(PlainProgressLogger):
    """Progress logger that stays silent on stdout for MCP stdio transport.

    Entering/exiting the base :class:`ProgressLogger` emits OSC escape
    sequences on ``sys.stdout``. Over the stdio transport that output
    corrupts the JSON-RPC stream, so this subclass replaces the context
    manager protocol: nesting depth and start/end timestamps are still
    tracked, but nothing is written to stdout.
    """

    def __enter__(self) -> Self:
        """Open one nesting level and return the logger itself.

        The start time is stamped only when the outermost level is
        entered; nested entries merely bump the depth counter.
        """
        is_outermost = self._depth == 0
        self._depth += 1
        if is_outermost:
            self.start_time = time.perf_counter()
        # Deliberately no on_enter call (verbose=False) and no stdout
        # escape sequences.
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close one nesting level without suppressing any exception.

        The end time is stamped only when the outermost level is left.
        """
        self._depth -= 1
        if self._depth != 0:
            # Still inside an outer ``with`` block; nothing to record yet.
            return
        # Deliberately no on_exit call (verbose=False) and no stdout
        # escape sequences.
        self.end_time = time.perf_counter()
25 changes: 10 additions & 15 deletions kumo_rfm_mcp/tools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from fastmcp import FastMCP
from fastmcp.exceptions import ToolError
from kumoai.experimental import rfm
from kumoai.experimental.rfm.infer.dtype import infer_dtype
from kumoai.graph import Edge
from kumoai.utils import ProgressLogger
from kumoapi.typing import Dtype, Stype
from pydantic import Field

Expand All @@ -22,6 +22,7 @@
UpdatedGraphMetadata,
UpdateGraphMetadata,
)
from kumo_rfm_mcp.logger import McpProgressLogger

_materialize_lock = asyncio.Lock()

Expand Down Expand Up @@ -49,7 +50,7 @@ def inspect_graph_metadata() -> GraphMetadata:
dtypes[column] = table[column].dtype
stypes[column] = table[column].stype
else:
dtypes[column] = rfm.utils.to_dtype(table._data[column])
dtypes[column] = infer_dtype(table._data[column])
stypes[column] = None
tables.append(
TableMetadata(
Expand Down Expand Up @@ -284,25 +285,18 @@ async def materialize_graph() -> MaterializedGraphInfo:

def _materialize_graph() -> rfm.KumoRFM:
try:
logger = ProgressLogger("Materializing graph")
logger = McpProgressLogger("Materializing graph", verbose=False)
return rfm.KumoRFM(session.graph, verbose=logger)
except Exception as e:
raise ToolError(f"Failed to materialize graph: {e}")

def _get_info(model: rfm.KumoRFM) -> MaterializedGraphInfo:
store = model._graph_store
store = model._sampler._graph_store
num_nodes = sum(len(df) for df in store.df_dict.values())
num_edges = sum(len(row) for row in store.row_dict.values())
time_ranges = {}
for table in session.graph.tables.values():
if table._time_column is None:
continue
time = store.df_dict[table.name][table._time_column]
if table.name in store.mask_dict.keys():
time = time[store.mask_dict[table.name]]
if len(time) == 0:
continue
time_ranges[table.name] = f"{time.min()} - {time.max()}"
for table_name, (min_t, max_t) in store.min_max_time_dict.items():
time_ranges[table_name] = f"{min_t} - {max_t}"

return MaterializedGraphInfo(
num_nodes=num_nodes,
Expand Down Expand Up @@ -348,11 +342,12 @@ async def lookup_table_rows(

def _lookup_table_rows() -> TableSourcePreview:
try:
node_ids = model._graph_store.get_node_id(
store = model._sampler._graph_store
node_ids = store.get_node_id(
table_name=table_name,
pkey=pd.Series(ids),
)
df = model._graph_store.df_dict[table_name].iloc[node_ids]
df = store.df_dict[table_name].iloc[node_ids]
except Exception as e:
raise ToolError(str(e)) from e

Expand Down
8 changes: 4 additions & 4 deletions kumo_rfm_mcp/tools/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pandas as pd
from fastmcp import FastMCP
from fastmcp.exceptions import ToolError
from kumoai.utils import ProgressLogger
from pydantic import Field

from kumo_rfm_mcp import (
Expand All @@ -14,6 +13,7 @@
PredictResponse,
SessionManager,
)
from kumo_rfm_mcp.logger import McpProgressLogger

query_doc = ("The predictive query string, e.g., "
"'PREDICT COUNT(orders.*, 0, 30, days)>0 FOR EACH users.user_id' "
Expand Down Expand Up @@ -139,7 +139,7 @@ async def predict(
anchor_time = pd.Timestamp(anchor_time)

def _predict() -> PredictResponse:
logger = ProgressLogger(query)
logger = McpProgressLogger(query, verbose=False)

try:
df = model.predict(
Expand Down Expand Up @@ -213,7 +213,7 @@ async def evaluate(
anchor_time = pd.Timestamp(anchor_time)

def _evaluate() -> EvaluateResponse:
logger = ProgressLogger(query)
logger = McpProgressLogger(query, verbose=False)

try:
df = model.evaluate(
Expand Down Expand Up @@ -290,7 +290,7 @@ async def explain(
anchor_time = pd.Timestamp(anchor_time)

def _explain() -> ExplanationResponse:
logger = ProgressLogger(query)
logger = McpProgressLogger(query, verbose=False)

try:
out = model.predict(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ keywords = [
]
requires-python = ">=3.10"
dependencies = [
"kumoai==2.10.1",
"kumoai>=2.15.0",
"fastmcp>=2.2.7,<3",
]
classifiers = [
Expand Down
Loading