Skip to content
6 changes: 5 additions & 1 deletion superset/mcp_service/chart/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,11 @@ class GetChartDataRequest(QueryCacheControl):

identifier: int | str = Field(description="Chart identifier (ID, UUID)")
limit: int | None = Field(
default=100, description="Maximum number of data rows to return"
default=None,
description=(
"Maximum number of data rows to return. If not specified, uses the "
"chart's configured row limit."
),
)
format: Literal["json", "csv", "excel"] = Field(
default="json", description="Data export format"
Expand Down
95 changes: 65 additions & 30 deletions superset/mcp_service/chart/tool/get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Any, Dict, List, TYPE_CHECKING

from fastmcp import Context
from flask import current_app
from superset_core.mcp import tool

if TYPE_CHECKING:
Expand All @@ -36,11 +37,13 @@
PerformanceMetadata,
)
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.schema_utils import parse_request

logger = logging.getLogger(__name__)


@tool(tags=["data"])
@parse_request(GetChartDataRequest)
async def get_chart_data( # noqa: C901
request: GetChartDataRequest, ctx: Context
) -> ChartData | ChartError:
Expand All @@ -52,6 +55,7 @@ async def get_chart_data( # noqa: C901
- Numeric ID or UUID lookup
- Multiple formats: json, csv, excel
- Cache control: use_cache, force_refresh, cache_timeout
- Optional row limit override (respects chart's configured limits)

Returns underlying data in requested format with cache status.
"""
Expand Down Expand Up @@ -121,40 +125,71 @@ async def get_chart_data( # noqa: C901

try:
await ctx.report_progress(2, 4, "Preparing data query")
# Get chart data using the existing API
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.common.query_context_factory import QueryContextFactory

# Parse the form_data to get query context
form_data = utils_json.loads(chart.params) if chart.params else {}
await ctx.debug(
"Chart form data parsed: has_filters=%s, has_groupby=%s, has_metrics=%s"
% (
bool(form_data.get("filters")),
bool(form_data.get("groupby")),
bool(form_data.get("metrics")),
# Use the chart's saved query_context - this is the key!
# The query_context contains all the information needed to reproduce
# the chart's data exactly as shown in the visualization
query_context_json = None
if chart.query_context:
try:
query_context_json = utils_json.loads(chart.query_context)
await ctx.debug(
"Using chart's saved query_context for data retrieval"
)
except (TypeError, ValueError) as e:
await ctx.warning(
"Failed to parse chart query_context: %s" % str(e)
)

if query_context_json is None:
# Fallback: Chart has no saved query_context
# This can happen with older charts that haven't been re-saved
await ctx.warning(
"Chart has no saved query_context. "
"Data may not match the chart visualization exactly. "
"Consider re-saving the chart to enable full data retrieval."
)
)
# Try to construct from form_data as a fallback
form_data = utils_json.loads(chart.params) if chart.params else {}
from superset.common.query_context_factory import QueryContextFactory

factory = QueryContextFactory()
row_limit = (
request.limit
or form_data.get("row_limit")
or current_app.config["ROW_LIMIT"]
)
query_context = factory.create(
datasource={
"id": chart.datasource_id,
"type": chart.datasource_type,
},
queries=[
{
"filters": form_data.get("filters", []),
"columns": form_data.get("groupby", []),
"metrics": form_data.get("metrics", []),
"row_limit": row_limit,
"order_desc": True,
}
],
form_data=form_data,
force=request.force_refresh,
)
else:
# Apply request overrides to the saved query_context
query_context_json["force"] = request.force_refresh

# Create a proper QueryContext using the factory with cache control
factory = QueryContextFactory()
query_context = factory.create(
datasource={"id": chart.datasource_id, "type": chart.datasource_type},
queries=[
{
"filters": form_data.get("filters", []),
"columns": form_data.get("groupby", []),
"metrics": form_data.get("metrics", []),
"row_limit": request.limit or 100,
"order_desc": True,
# Apply cache control from request
"cache_timeout": request.cache_timeout,
}
],
form_data=form_data,
# Use cache unless force_refresh is True
force=request.force_refresh,
)
# Apply row limit if specified (respects chart's configured limits)
if request.limit:
for query in query_context_json.get("queries", []):
query["row_limit"] = request.limit

# Create QueryContext from the saved context using the schema
# This is exactly how the API does it
query_context = ChartDataQueryContextSchema().load(query_context_json)

await ctx.report_progress(3, 4, "Executing data query")
await ctx.debug(
Expand Down
99 changes: 99 additions & 0 deletions tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Tests for the get_chart_data request schema
"""

import pytest

from superset.mcp_service.chart.schemas import GetChartDataRequest


class TestGetChartDataRequestSchema:
"""Test the GetChartDataRequest schema validation."""

def test_default_request(self):
"""Test creating request with all defaults."""
request = GetChartDataRequest(identifier=1)

assert request.identifier == 1
assert request.limit is None # Uses chart's configured limit by default
assert request.format == "json"
assert request.use_cache is True
assert request.force_refresh is False
assert request.cache_timeout is None

def test_request_with_uuid_identifier(self):
"""Test creating request with UUID identifier."""
uuid = "a1b2c3d4-5678-90ab-cdef-1234567890ab"
request = GetChartDataRequest(identifier=uuid)

assert request.identifier == uuid

def test_request_with_custom_limit(self):
"""Test creating request with custom limit."""
request = GetChartDataRequest(identifier=1, limit=500)

assert request.limit == 500

def test_request_with_csv_format(self):
"""Test creating request with CSV format."""
request = GetChartDataRequest(identifier=1, format="csv")

assert request.format == "csv"

def test_request_with_excel_format(self):
"""Test creating request with Excel format."""
request = GetChartDataRequest(identifier=1, format="excel")

assert request.format == "excel"

def test_request_with_cache_control(self):
"""Test creating request with cache control options."""
request = GetChartDataRequest(
identifier=1,
use_cache=False,
force_refresh=True,
cache_timeout=3600,
)

assert request.use_cache is False
assert request.force_refresh is True
assert request.cache_timeout == 3600

def test_invalid_format(self):
"""Test that invalid format raises validation error."""
with pytest.raises(
ValueError, match="Input should be 'json', 'csv' or 'excel'"
):
GetChartDataRequest(identifier=1, format="invalid")

def test_model_dump_serialization(self):
"""Test that the request serializes correctly for JSON."""
request = GetChartDataRequest(
identifier=123,
limit=50,
format="json",
)

data = request.model_dump()

assert isinstance(data, dict)
assert data["identifier"] == 123
assert data["limit"] == 50
assert data["format"] == "json"
Loading