diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index f8e9d8186074..fdaef13e17b0 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -970,7 +970,11 @@ class GetChartDataRequest(QueryCacheControl): identifier: int | str = Field(description="Chart identifier (ID, UUID)") limit: int | None = Field( - default=100, description="Maximum number of data rows to return" + default=None, + description=( + "Maximum number of data rows to return. If not specified, uses the " + "chart's configured row limit." + ), ) format: Literal["json", "csv", "excel"] = Field( default="json", description="Data export format" diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index 4f6a931c6c32..d4d1804af4a4 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -23,6 +23,7 @@ from typing import Any, Dict, List, TYPE_CHECKING from fastmcp import Context +from flask import current_app from superset_core.mcp import tool if TYPE_CHECKING: @@ -36,11 +37,13 @@ PerformanceMetadata, ) from superset.mcp_service.utils.cache_utils import get_cache_status_from_result +from superset.mcp_service.utils.schema_utils import parse_request logger = logging.getLogger(__name__) @tool(tags=["data"]) +@parse_request(GetChartDataRequest) async def get_chart_data( # noqa: C901 request: GetChartDataRequest, ctx: Context ) -> ChartData | ChartError: @@ -52,6 +55,7 @@ async def get_chart_data( # noqa: C901 - Numeric ID or UUID lookup - Multiple formats: json, csv, excel - Cache control: use_cache, force_refresh, cache_timeout + - Optional row limit override (respects chart's configured limits) Returns underlying data in requested format with cache status. """ @@ -121,40 +125,71 @@ async def get_chart_data( # noqa: C901 try: await ctx.report_progress(2, 4, "Preparing data query") - # Get chart data using the existing API + from superset.charts.schemas import ChartDataQueryContextSchema from superset.commands.chart.data.get_data_command import ChartDataCommand - from superset.common.query_context_factory import QueryContextFactory - # Parse the form_data to get query context - form_data = utils_json.loads(chart.params) if chart.params else {} - await ctx.debug( - "Chart form data parsed: has_filters=%s, has_groupby=%s, has_metrics=%s" - % ( - bool(form_data.get("filters")), - bool(form_data.get("groupby")), - bool(form_data.get("metrics")), + # Use the chart's saved query_context - this is the key! + # The query_context contains all the information needed to reproduce + # the chart's data exactly as shown in the visualization + query_context_json = None + if chart.query_context: + try: + query_context_json = utils_json.loads(chart.query_context) + await ctx.debug( + "Using chart's saved query_context for data retrieval" + ) + except (TypeError, ValueError) as e: + await ctx.warning( + "Failed to parse chart query_context: %s" % str(e) + ) + + if query_context_json is None: + # Fallback: Chart has no saved query_context + # This can happen with older charts that haven't been re-saved + await ctx.warning( + "Chart has no saved query_context. " + "Data may not match the chart visualization exactly. " + "Consider re-saving the chart to enable full data retrieval." ) - ) + # Try to construct from form_data as a fallback + form_data = utils_json.loads(chart.params) if chart.params else {} + from superset.common.query_context_factory import QueryContextFactory + + factory = QueryContextFactory() + row_limit = ( + request.limit + or form_data.get("row_limit") + or current_app.config["ROW_LIMIT"] + ) + query_context = factory.create( + datasource={ + "id": chart.datasource_id, + "type": chart.datasource_type, + }, + queries=[ + { + "filters": form_data.get("filters", []), + "columns": form_data.get("groupby", []), + "metrics": form_data.get("metrics", []), + "row_limit": row_limit, + "order_desc": True, + } + ], + form_data=form_data, + force=request.force_refresh, + ) + else: + # Apply request overrides to the saved query_context + query_context_json["force"] = request.force_refresh - # Create a proper QueryContext using the factory with cache control - factory = QueryContextFactory() - query_context = factory.create( - datasource={"id": chart.datasource_id, "type": chart.datasource_type}, - queries=[ - { - "filters": form_data.get("filters", []), - "columns": form_data.get("groupby", []), - "metrics": form_data.get("metrics", []), - "row_limit": request.limit or 100, - "order_desc": True, - # Apply cache control from request - "cache_timeout": request.cache_timeout, - } - ], - form_data=form_data, - # Use cache unless force_refresh is True - force=request.force_refresh, - ) + # Apply row limit if specified (respects chart's configured limits) + if request.limit: + for query in query_context_json.get("queries", []): + query["row_limit"] = request.limit + + # Create QueryContext from the saved context using the schema + # This is exactly how the API does it + query_context = ChartDataQueryContextSchema().load(query_context_json) await ctx.report_progress(3, 4, "Executing data query") await ctx.debug( diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py new file mode 100644 index 000000000000..30b61980a907 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Tests for the get_chart_data request schema +""" + +import pytest + +from superset.mcp_service.chart.schemas import GetChartDataRequest + + +class TestGetChartDataRequestSchema: + """Test the GetChartDataRequest schema validation.""" + + def test_default_request(self): + """Test creating request with all defaults.""" + request = GetChartDataRequest(identifier=1) + + assert request.identifier == 1 + assert request.limit is None # Uses chart's configured limit by default + assert request.format == "json" + assert request.use_cache is True + assert request.force_refresh is False + assert request.cache_timeout is None + + def test_request_with_uuid_identifier(self): + """Test creating request with UUID identifier.""" + uuid = "a1b2c3d4-5678-90ab-cdef-1234567890ab" + request = GetChartDataRequest(identifier=uuid) + + assert request.identifier == uuid + + def test_request_with_custom_limit(self): + """Test creating request with custom limit.""" + request = GetChartDataRequest(identifier=1, limit=500) + + assert request.limit == 500 + + def test_request_with_csv_format(self): + """Test creating request with CSV format.""" + request = GetChartDataRequest(identifier=1, format="csv") + + assert request.format == "csv" + + def test_request_with_excel_format(self): + """Test creating request with Excel format.""" + request = GetChartDataRequest(identifier=1, format="excel") + + assert request.format == "excel" + + def test_request_with_cache_control(self): + """Test creating request with cache control options.""" + request = GetChartDataRequest( + identifier=1, + use_cache=False, + force_refresh=True, + cache_timeout=3600, + ) + + assert request.use_cache is False + assert request.force_refresh is True + assert request.cache_timeout == 3600 + + def test_invalid_format(self): + """Test that invalid format raises validation error.""" + with pytest.raises( + ValueError, match="Input should be 'json', 'csv' or 'excel'" + ): + GetChartDataRequest(identifier=1, format="invalid") + + def test_model_dump_serialization(self): + """Test that the request serializes correctly for JSON.""" + request = GetChartDataRequest( + identifier=123, + limit=50, + format="json", + ) + + data = request.model_dump() + + assert isinstance(data, dict) + assert data["identifier"] == 123 + assert data["limit"] == 50 + assert data["format"] == "json"