diff --git a/README.md b/README.md index 93bdad15..c5571e5d 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,11 @@ Any questions? Ask on the Discord channel! [![Discord](https://img.shields.io/ba ## News 🔥🔥🔥 +- [07-10-2025] Data Formulator 0.2.2: Start with an analysis goal + - Some key frontend performance updates. + - You can start your exploration with a goal, or, tab and see if the agent can recommend some good exploration ideas for you. [Demo](https://github.com/microsoft/data-formulator/pull/176) -- [05-13-2025] Data Formulator 0.2.3 / 0.2.4: External Data Loader +- [05-13-2025] Data Formulator 0.2.1.3/4: External Data Loader - We introduced external data loader class to make import data easier. [Readme](https://github.com/microsoft/data-formulator/tree/main/py-src/data_formulator/data_loader) and [Demo](https://github.com/microsoft/data-formulator/pull/155) - Current data loaders: MySQL, Azure Data Explorer (Kusto), Azure Blob and Amazon S3 (json, parquet, csv). - [07-01-2025] Updated with: Postgresql, mssql. diff --git a/py-src/data_formulator/agents/__init__.py b/py-src/data_formulator/agents/__init__.py index 0fc88b50..4a201b8d 100644 --- a/py-src/data_formulator/agents/__init__.py +++ b/py-src/data_formulator/agents/__init__.py @@ -20,5 +20,5 @@ "SQLDataRecAgent", "DataLoadAgent", "SortDataAgent", - "DataCleanAgent" + "DataCleanAgent", ] \ No newline at end of file diff --git a/py-src/data_formulator/agents/agent_data_load.py b/py-src/data_formulator/agents/agent_data_load.py index 920ade5c..0dcef299 100644 --- a/py-src/data_formulator/agents/agent_data_load.py +++ b/py-src/data_formulator/agents/agent_data_load.py @@ -12,7 +12,10 @@ SYSTEM_PROMPT = '''You are a data scientist to help user infer data types based off the table provided by the user. -Given a dataset provided by the user, identify their type and semantic type, and provide a very short summary of the dataset. +Given a dataset provided by the user, +1. identify their type and semantic type +2. provide a very short summary of the dataset. +3. provide a list of (5-10) explorative questions that can help users get started with data visualizations. Types to consider include: string, number, date Semantic types to consider include: Location, Year, Month, Day, Date, Time, DateTime, Range, Duration, Name, Percentage, String, Number @@ -34,7 +37,8 @@ "field2": {"type": ..., "semantic_type": ..., "sort_order": null}, ... }, - "data summary": ... // a short summary of the data + "data summary": ... // a short summary of the data, + "explorative_questions": [...], // a list of explorative questions that can help users get started with data visualizations } ``` ''' @@ -76,7 +80,13 @@ "total": {"type": "number", "semantic_type": "Number", "sort_order": null}, "group": {"type": "string", "semantic_type": "Range", "sort_order": ["<10000", "10000 to 14999", "15000 to 24999", "25000 to 34999", "35000 to 49999", "50000 to 74999", "75000 to 99999", "100000 to 149999", "150000 to 199999", "200000+"]} }, - "data summary": "The dataset contains information about income distribution across different states in the USA. It includes fields for state names, regions, state IDs, percentage of total income, total income, and income groups." + "data summary": "The dataset contains information about income distribution across different states in the USA. It includes fields for state names, regions, state IDs, percentage of total income, total income, and income groups.", + "explorative_questions": [ + "What is the average income across different states?", + "What is the distribution of income across different regions?", + "What is the relationship between income and state ID?", + "What is the relationship between income and region?" + ] } ``` @@ -121,7 +131,13 @@ "sort_order": null } }, - "data_summary": "This dataset contains weather information for the cities of Seattle and Atlanta. The fields include the date, city name, and temperature readings. The 'Date' field represents dates in a string format, the 'City' field represents city names, and the 'Temperature' field represents temperature values in integer format." + "data_summary": "This dataset contains weather information for the cities of Seattle and Atlanta. The fields include the date, city name, and temperature readings. The 'Date' field represents dates in a string format, the 'City' field represents city names, and the 'Temperature' field represents temperature values in integer format.", + "explorative_questions": [ + "What is the average temperature across different cities?", + "What is the distribution of temperature across different dates?", + "What is the relationship between temperature and city?", + "What is the relationship between temperature and date?" + ] }```''' class DataLoadAgent(object): diff --git a/py-src/data_formulator/agents/agent_query_completion.py b/py-src/data_formulator/agents/agent_query_completion.py index 0dd6f494..f60a2fa9 100644 --- a/py-src/data_formulator/agents/agent_query_completion.py +++ b/py-src/data_formulator/agents/agent_query_completion.py @@ -54,7 +54,7 @@ def __init__(self, client): def run(self, data_source_metadata, query): - user_query = f"[DATA SOURCE]\n\n{json.dumps(data_source_metadata, indent=2)}\n\n[USER INPUTS]\n\n{query}\n\n[REASONING]\n" + user_query = f"[DATA SOURCE]\n\n{json.dumps(data_source_metadata, indent=2)}\n\n[USER INPUTS]\n\n{query}\n\n" logger.info(user_query) @@ -63,11 +63,11 @@ def run(self, data_source_metadata, query): ###### the part that calls open_ai response = self.client.get_completion(messages = messages) - response_content = '[REASONING]\n' + response.choices[0].message.content + response_content = response.choices[0].message.content logger.info(f"=== query completion output ===>\n{response_content}\n") - reasoning = extract_json_objects(response_content.split("[REASONING]")[1].split("[QUERY]")[0].strip())[0] + reasoning = extract_json_objects(response_content.split("[QUERY]")[0].strip())[0] output_query = response_content.split("[QUERY]")[1].strip() # Extract the query by removing the language markers diff --git a/py-src/data_formulator/data_loader/azure_blob_data_loader.py b/py-src/data_formulator/data_loader/azure_blob_data_loader.py index 1df51690..a530c987 100644 --- a/py-src/data_formulator/data_loader/azure_blob_data_loader.py +++ b/py-src/data_formulator/data_loader/azure_blob_data_loader.py @@ -369,4 +369,4 @@ def ingest_data_from_query(self, query: str, name_as: str): # Execute the query and get results as a DataFrame df = self.duck_db_conn.execute(query).df() # Use the base class's method to ingest the DataFrame - self.ingest_df_to_duckdb(df, name_as) \ No newline at end of file + self.ingest_df_to_duckdb(df, sanitize_table_name(name_as)) \ No newline at end of file diff --git a/py-src/data_formulator/data_loader/mssql_data_loader.py b/py-src/data_formulator/data_loader/mssql_data_loader.py index f2d930d3..048f7ca8 100644 --- a/py-src/data_formulator/data_loader/mssql_data_loader.py +++ b/py-src/data_formulator/data_loader/mssql_data_loader.py @@ -445,7 +445,7 @@ def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame: try: df = self._execute_query(query) # Use the base class's method to ingest the DataFrame - self.ingest_df_to_duckdb(df, name_as) + self.ingest_df_to_duckdb(df, sanitize_table_name(name_as)) log.info(f"Successfully ingested {len(df)} rows from custom query to {name_as}") return df except Exception as e: diff --git a/py-src/data_formulator/data_loader/mysql_data_loader.py b/py-src/data_formulator/data_loader/mysql_data_loader.py index e96bb89a..1a26d4ac 100644 --- a/py-src/data_formulator/data_loader/mysql_data_loader.py +++ b/py-src/data_formulator/data_loader/mysql_data_loader.py @@ -63,7 +63,9 @@ def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnecti try: self.duck_db_conn.execute("DETACH mysqldb;") except: - pass # Ignore if mysqldb doesn't exist # Register MySQL connection + pass # Ignore if mysqldb doesn't exist + + # Register MySQL connection self.duck_db_conn.execute(f"ATTACH '{attach_string}' AS mysqldb (TYPE mysql);") def list_tables(self, table_filter: str = None): @@ -129,4 +131,4 @@ def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame: # Execute the query and get results as a DataFrame df = self.duck_db_conn.execute(query).df() # Use the base class's method to ingest the DataFrame - self.ingest_df_to_duckdb(df, name_as) \ No newline at end of file + self.ingest_df_to_duckdb(df, sanitize_table_name(name_as)) \ No newline at end of file diff --git a/py-src/data_formulator/data_loader/postgresql_data_loader.py b/py-src/data_formulator/data_loader/postgresql_data_loader.py index ad5c298d..0400ace0 100644 --- a/py-src/data_formulator/data_loader/postgresql_data_loader.py +++ b/py-src/data_formulator/data_loader/postgresql_data_loader.py @@ -128,5 +128,5 @@ def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame: # Execute the query and get results as a DataFrame df = self.duck_db_conn.execute(query).df() # Use the base class's method to ingest the DataFrame - self.ingest_df_to_duckdb(df, name_as) + self.ingest_df_to_duckdb(df, sanitize_table_name(name_as)) return df diff --git a/py-src/data_formulator/data_loader/s3_data_loader.py b/py-src/data_formulator/data_loader/s3_data_loader.py index 1285016a..666df6af 100644 --- a/py-src/data_formulator/data_loader/s3_data_loader.py +++ b/py-src/data_formulator/data_loader/s3_data_loader.py @@ -203,4 +203,4 @@ def ingest_data_from_query(self, query: str, name_as: str): # Execute the query and get results as a DataFrame df = self.duck_db_conn.execute(query).df() # Use the base class's method to ingest the DataFrame - self.ingest_df_to_duckdb(df, name_as) \ No newline at end of file + self.ingest_df_to_duckdb(df, sanitize_table_name(name_as)) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5f82bddd..df187d81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "data_formulator" -version = "0.2.1.5" +version = "0.2.2" requires-python = ">=3.9" authors = [ diff --git a/src/app/App.tsx b/src/app/App.tsx index a5430edc..b490eff3 100644 --- a/src/app/App.tsx +++ b/src/app/App.tsx @@ -519,7 +519,7 @@ export const AppFC: FC = function AppFC(appProps) { //user is not logged in, do not show logout button //console.error(err) }); - }, []) + }, []); useEffect(() => { document.title = toolName; diff --git a/src/app/dfSlice.tsx b/src/app/dfSlice.tsx index 013abf4d..bfb78265 100644 --- a/src/app/dfSlice.tsx +++ b/src/app/dfSlice.tsx @@ -16,12 +16,11 @@ import { handleSSEMessage } from './SSEActions'; enableMapSet(); -export const generateFreshChart = (tableRef: string, chartType?: string, source: "user" | "trigger" = "user") : Chart => { - let realChartType = chartType || "?" +export const generateFreshChart = (tableRef: string, chartType: string, source: "user" | "trigger" = "user") : Chart => { return { id: `chart-${Date.now()- Math.floor(Math.random() * 10000)}`, - chartType: realChartType, - encodingMap: Object.assign({}, ...getChartChannels(realChartType).map((channel) => ({ [channel]: { channel: channel, bin: false } }))), + chartType: chartType, + encodingMap: Object.assign({}, ...getChartChannels(chartType).map((channel) => ({ [channel]: { channel: channel, bin: false } }))), tableRef: tableRef, saved: false, source: source, @@ -45,12 +44,11 @@ export interface ModelConfig { } // Define model slot types -export type ModelSlotType = 'generation' | 'hint'; +export const MODEL_SLOT_TYPES = ['generation', 'hint'] as const; +export type ModelSlotType = typeof MODEL_SLOT_TYPES[number]; -export interface ModelSlots { - generation?: string; // model id assigned to generation tasks - hint?: string; // model id assigned to hint tasks -} +// Derive ModelSlots interface from the constant +export type ModelSlots = Partial>; // Define a type for the slice state export interface DataFormulatorState { @@ -271,7 +269,7 @@ export const dataFormulatorSlice = createSlice({ // avoid resetting inputted models // state.oaiModels = state.oaiModels.filter((m: any) => m.endpoint != 'default'); - state.modelSlots = {}; + // state.modelSlots = {}; state.testedModels = []; state.tables = []; @@ -358,11 +356,13 @@ export const dataFormulatorSlice = createSlice({ }, loadTable: (state, action: PayloadAction) => { let table = action.payload; + let freshChart = generateFreshChart(table.id, '?') as Chart; state.tables = [...state.tables, table]; + state.charts = [...state.charts, freshChart]; state.conceptShelfItems = [...state.conceptShelfItems, ...getDataFieldItems(table)]; state.focusedTableId = table.id; - state.focusedChartId = undefined; + state.focusedChartId = freshChart.id; }, deleteTable: (state, action: PayloadAction) => { let tableId = action.payload; @@ -452,7 +452,7 @@ export const dataFormulatorSlice = createSlice({ }); } }, - createNewChart: (state, action: PayloadAction<{chartType?: string, tableId?: string}>) => { + createNewChart: (state, action: PayloadAction<{chartType: string, tableId: string}>) => { let chartType = action.payload.chartType; let tableId = action.payload.tableId || state.tables[0].id; let freshChart = generateFreshChart(tableId, chartType, "user") as Chart; @@ -745,6 +745,11 @@ export const dataFormulatorSlice = createSlice({ return field; } }) + + if (data["result"][0]["explorative_questions"] && data["result"][0]["explorative_questions"].length > 0) { + let table = state.tables.find(t => t.id == tableId) as DictTable; + table.explorativeQuestions = data["result"][0]["explorative_questions"] as string[]; + } } }) .addCase(fetchAvailableModels.fulfilled, (state, action) => { @@ -763,8 +768,12 @@ export const dataFormulatorSlice = createSlice({ ...state.testedModels.filter(t => !defaultModels.map((m: ModelConfig) => m.id).includes(t.id)) ] - if (state.modelSlots.generation == undefined && defaultModels.length > 0) { - state.modelSlots.generation = defaultModels[0].id; + if (defaultModels.length > 0) { + for (const slotType of MODEL_SLOT_TYPES) { + if (state.modelSlots[slotType] == undefined) { + state.modelSlots[slotType] = defaultModels[0].id; + } + } } // console.log("load model complete"); @@ -796,7 +805,7 @@ export const dfSelectors = { return modelId ? state.models.find(m => m.id === modelId) : undefined; }, getAllSlotTypes: () : ModelSlotType[] => { - return ['generation', 'hint']; + return [...MODEL_SLOT_TYPES]; }, getActiveBaseTableIds: (state: DataFormulatorState) => { let focusedTableId = state.focusedTableId; diff --git a/src/components/ComponentType.tsx b/src/components/ComponentType.tsx index 21b9a8e7..9476c10b 100644 --- a/src/components/ComponentType.tsx +++ b/src/components/ComponentType.tsx @@ -76,6 +76,7 @@ export interface DictTable { rowCount: number; // total number of rows in the full table }; anchored: boolean; // whether this table is anchored as a persistent table used to derive other tables + explorativeQuestions: string[]; // a list of (3-5) explorative questions that can help users get started with data visualizations } export function createDictTable( @@ -83,7 +84,8 @@ export function createDictTable( derive: {code: string, codeExpl: string, source: string[], dialog: any[], trigger: Trigger} | undefined = undefined, virtual: {tableId: string, rowCount: number} | undefined = undefined, - anchored: boolean = false) : DictTable { + anchored: boolean = false, + explorativeQuestions: string[] = []) : DictTable { let names = Object.keys(rows[0]) @@ -95,7 +97,8 @@ export function createDictTable( types: names.map(name => inferTypeFromValueArray(rows.map(r => r[name]))), derive, virtual, - anchored + anchored, + explorativeQuestions } } diff --git a/src/views/ChartRecBox.tsx b/src/views/ChartRecBox.tsx new file mode 100644 index 00000000..649f2cdb --- /dev/null +++ b/src/views/ChartRecBox.tsx @@ -0,0 +1,525 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { FC, useEffect, useState } from 'react' +import { useSelector, useDispatch } from 'react-redux' +import { DataFormulatorState, dfActions, dfSelectors, fetchCodeExpl, fetchFieldSemanticType, generateFreshChart } from '../app/dfSlice'; + +import { + Box, + Typography, + FormControl, + InputLabel, + Select, + MenuItem, + ListSubheader, + ListItemIcon, + ListItemText, + IconButton, + Tooltip, + TextField, + Stack, + Card, + Chip, + Autocomplete, + Menu, + SxProps, + LinearProgress, + CircularProgress, + Divider, +} from '@mui/material'; + +import React from 'react'; + +import { Channel, EncodingItem, ConceptTransformation, Chart, FieldItem, Trigger, duplicateChart, EncodingMap } from "../components/ComponentType"; + +import _ from 'lodash'; + +import '../scss/EncodingShelf.scss'; +import { createDictTable, DictTable } from "../components/ComponentType"; + +import { getUrls, resolveChartFields } from '../app/utils'; + +import AddIcon from '@mui/icons-material/Add'; + +import { AppDispatch } from '../app/store'; +import PrecisionManufacturing from '@mui/icons-material/PrecisionManufacturing'; +import { Type } from '../data/types'; +import CloseIcon from '@mui/icons-material/Close'; +import InsightsIcon from '@mui/icons-material/Insights'; + +export interface ChartRecBoxProps { + tableId: string; + placeHolderChartId: string; + sx?: SxProps; +} + +// Table selector component for ChartRecBox +const NLTableSelector: FC<{ + selectedTableIds: string[], + tables: DictTable[], + updateSelectedTableIds: (tableIds: string[]) => void, + requiredTableIds?: string[] +}> = ({ selectedTableIds, tables, updateSelectedTableIds, requiredTableIds = [] }) => { + const [anchorEl, setAnchorEl] = useState(null); + const open = Boolean(anchorEl); + + const handleClick = (event: React.MouseEvent) => { + setAnchorEl(event.currentTarget); + }; + + const handleClose = () => { + setAnchorEl(null); + }; + + const handleTableSelect = (table: DictTable) => { + if (!selectedTableIds.includes(table.id)) { + updateSelectedTableIds([...selectedTableIds, table.id]); + } + handleClose(); + }; + + return ( + + {selectedTableIds.map((tableId) => { + const isRequired = requiredTableIds.includes(tableId); + return ( + t.id == tableId)?.displayId} + size="small" + sx={{ + height: 16, + fontSize: '10px', + borderRadius: '2px', + bgcolor: isRequired ? 'rgba(25, 118, 210, 0.2)' : 'rgba(25, 118, 210, 0.1)', + color: 'rgba(0, 0, 0, 0.7)', + '& .MuiChip-label': { + pl: '4px', + pr: '6px' + } + }} + deleteIcon={isRequired ? undefined : } + onDelete={isRequired ? undefined : () => updateSelectedTableIds(selectedTableIds.filter(id => id !== tableId))} + /> + ); + })} + + + + + + + {tables + .filter(t => t.derive === undefined || t.anchored) + .map((table) => { + const isSelected = selectedTableIds.includes(table.id); + const isRequired = requiredTableIds.includes(table.id); + return ( + handleTableSelect(table)} + sx={{ + fontSize: '12px', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center' + }} + > + {table.displayId} + {isRequired && (required)} + + ); + }) + } + + + ); +}; + +export const ChartRecBox: FC = function ({ tableId, placeHolderChartId, sx }) { + const dispatch = useDispatch(); + + // reference to states + const tables = useSelector((state: DataFormulatorState) => state.tables); + const config = useSelector((state: DataFormulatorState) => state.config); + const conceptShelfItems = useSelector((state: DataFormulatorState) => state.conceptShelfItems); + const allCharts = useSelector(dfSelectors.getAllCharts); + const activeModel = useSelector(dfSelectors.getActiveModel); + + const [prompt, setPrompt] = useState(""); + const [isFormulating, setIsFormulating] = useState(false); + // Remove the randomQuestion state since we'll generate it on demand + + // Use the provided tableId and find additional available tables for multi-table operations + const currentTable = tables.find(t => t.id === tableId); + + // Remove the useEffect that sets randomQuestion + + const availableTables = tables.filter(t => t.derive === undefined || t.anchored); + const [additionalTableIds, setAdditionalTableIds] = useState([]); + + // Combine the main tableId with additional selected tables + const selectedTableIds = currentTable ? [tableId, ...additionalTableIds] : []; + + const handleTableSelectionChange = (newTableIds: string[]) => { + // Filter out the main tableId since it's always included + const additionalIds = newTableIds.filter(id => id !== tableId); + setAdditionalTableIds(additionalIds); + }; + + // Function to get a random question from the list + const getQuestion = (random: boolean = false): string => { + if (currentTable?.explorativeQuestions && currentTable.explorativeQuestions.length > 0) { + const index = random ? Math.floor(Math.random() * currentTable.explorativeQuestions.length) : 0; + return currentTable.explorativeQuestions[index]; + } + return "Show something interesting about the data"; + }; + + // Handle tab key press for auto-completion + const handleKeyDown = (event: React.KeyboardEvent) => { + if (event.key === 'Tab' && !event.shiftKey) { + event.preventDefault(); + if (prompt.trim() === "") { + setPrompt(getQuestion(false)); + } + } + }; + + const deriveDataFromNL = (instruction: string) => { + + if (selectedTableIds.length === 0 || instruction.trim() === "") { + return; + } + + if (placeHolderChartId) { + dispatch(dfActions.updateChartType({chartType: "Auto", chartId: placeHolderChartId})); + dispatch(dfActions.changeChartRunningStatus({chartId: placeHolderChartId, status: true})); + } + + const actionTables = selectedTableIds.map(id => tables.find(t => t.id === id) as DictTable); + + // Validate table selection + const firstTableId = selectedTableIds[0]; + if (!firstTableId) { + dispatch(dfActions.addMessages({ + "timestamp": Date.now(), + "type": "error", + "component": "chart builder", + "value": "No table selected for data formulation.", + })); + return; + } + + // Set formulating status without creating Auto chart + setIsFormulating(true); + + const token = String(Date.now()); + const messageBody = JSON.stringify({ + token: token, + mode: 'formulate', + input_tables: actionTables.map(t => ({ + name: t.virtual?.tableId || t.id.replace(/\.[^/.]+$/, ""), + rows: t.rows + })), + new_fields: [], // No specific fields, let AI decide + extra_prompt: instruction, + model: activeModel, + max_repair_attempts: config.maxRepairAttempts, + language: actionTables.some(t => t.virtual) ? "sql" : "python" + }); + + if (process.env.NODE_ENV !== 'production') { + console.debug("debug: messageBody", messageBody); + } + + const engine = getUrls().DERIVE_DATA; + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), config.formulateTimeoutSeconds * 1000); + + fetch(engine, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: messageBody, + signal: controller.signal + }) + .then((response) => response.json()) + .then((data) => { + setIsFormulating(false); + + if (placeHolderChartId) { + dispatch(dfActions.changeChartRunningStatus({chartId: placeHolderChartId, status: false})); + } + + if (data.results.length > 0) { + if (data["token"] === token) { + const candidates = data["results"].filter((item: any) => item["status"] === "ok"); + + if (candidates.length === 0) { + const errorMessage = data.results[0].content; + const code = data.results[0].code; + + dispatch(dfActions.addMessages({ + "timestamp": Date.now(), + "type": "error", + "component": "chart builder", + "value": `Data formulation failed, please try again.`, + "code": code, + "detail": errorMessage + })); + } else { + const candidate = candidates[0]; + const code = candidate["code"]; + const rows = candidate["content"]["rows"]; + const dialog = candidate["dialog"]; + const refinedGoal = candidate['refined_goal']; + + // Generate table ID + const genTableId = () => { + let tableSuffix = Number.parseInt((Date.now() - Math.floor(Math.random() * 10000)).toString().slice(-2)); + let tableId = `table-${tableSuffix}`; + while (tables.find(t => t.id === tableId) !== undefined) { + tableSuffix = tableSuffix + 1; + tableId = `table-${tableSuffix}`; + } + return tableId; + }; + + const candidateTableId = candidate["content"]["virtual"] + ? candidate["content"]["virtual"]["table_name"] + : genTableId(); + + // Create new table + const candidateTable = createDictTable( + candidateTableId, + rows, + undefined // No derive info for ChartRecBox - it's NL-driven without triggers + ); + + let refChart = generateFreshChart(firstTableId, 'Auto') as Chart; + refChart.source = 'trigger'; + + // Add derive info manually since ChartRecBox doesn't use triggers + candidateTable.derive = { + code: code, + codeExpl: "", + source: selectedTableIds, + dialog: dialog, + trigger: { + tableId: firstTableId, + sourceTableIds: selectedTableIds, + instruction: instruction, + chart: refChart, // No upfront chart reference + resultTableId: candidateTableId + } + }; + + if (candidate["content"]["virtual"] != null) { + candidateTable.virtual = { + tableId: candidate["content"]["virtual"]["table_name"], + rowCount: candidate["content"]["virtual"]["row_count"] + }; + } + + dispatch(dfActions.insertDerivedTables(candidateTable)); + + console.log("debug: candidateTable") + console.log(candidateTable) + + // Add missing concept items + const names = candidateTable.names; + const missingNames = names.filter(name => + !conceptShelfItems.some(field => field.name === name) + ); + + const conceptsToAdd = missingNames.map((name) => ({ + id: `concept-${name}-${Date.now()}`, + name: name, + type: "auto" as Type, + description: "", + source: "custom", + tableRef: "custom", + temporary: true, + domain: [], + } as FieldItem)); + + dispatch(dfActions.addConceptItems(conceptsToAdd)); + dispatch(fetchFieldSemanticType(candidateTable)); + dispatch(fetchCodeExpl(candidateTable)); + + // Create proper chart based on refined goal + const currentConcepts = [...conceptShelfItems.filter(c => names.includes(c.name)), ...conceptsToAdd]; + + let chartTypeMap: any = { + "line": "Line Chart", + "bar": "Bar Chart", + "point": "Scatter Plot", + "boxplot": "Boxplot" + }; + + const chartType = chartTypeMap[refinedGoal?.['chart_type']] || 'Scatter Plot'; + let newChart = generateFreshChart(candidateTable.id, chartType) as Chart; + newChart = resolveChartFields(newChart, currentConcepts, refinedGoal, candidateTable); + + console.log("debug: newChart") + console.log(newChart) + + // Create and focus the new chart directly + dispatch(dfActions.addAndFocusChart(newChart)); + + // Clean up + dispatch(dfActions.setFocusedTable(candidateTable.id)); + dispatch(dfActions.setVisPaneSize(640)); + + dispatch(dfActions.addMessages({ + "timestamp": Date.now(), + "component": "chart builder", + "type": "success", + "value": `Data formulation succeeded for: "${instruction}"` + })); + + // Clear the prompt after successful formulation + setPrompt(""); + + if (placeHolderChartId) { + dispatch(dfActions.deleteChartById(placeHolderChartId)); + } + } + } + } else { + dispatch(dfActions.addMessages({ + "timestamp": Date.now(), + "component": "chart builder", + "type": "error", + "value": "No result is returned from the data formulation agent. Please try again." + })); + + setIsFormulating(false); + } + }) + .catch((error) => { + setIsFormulating(false); + + if (error.name === 'AbortError') { + dispatch(dfActions.addMessages({ + "timestamp": Date.now(), + "component": "chart builder", + "type": "error", + "value": `Data formulation timed out after ${config.formulateTimeoutSeconds} seconds. Consider breaking down the task, using a different model or prompt, or increasing the timeout limit.`, + "detail": "Request exceeded timeout limit" + })); + } else { + dispatch(dfActions.addMessages({ + "timestamp": Date.now(), + "component": "chart builder", + "type": "error", + "value": `Data formulation failed, please try again.`, + "detail": error.message + })); + } + }); + }; + + const showTableSelector = availableTables.length > 1 && currentTable; + + return ( + + {showTableSelector && ( + + + Select additional tables: + + + + )} + + + setPrompt(event.target.value)} + onKeyDown={handleKeyDown} + slotProps={{ + inputLabel: { shrink: true }, + input: { + endAdornment: + deriveDataFromNL(prompt.trim())} + > + {isFormulating ? + + : } + + + } + }} + value={prompt} + label="Describe what you want to visualize" + placeholder={`e.g., ${getQuestion(false)}`} + fullWidth + multiline + variant="standard" + maxRows={4} + minRows={1} + /> + + + + surprise? + + + deriveDataFromNL(getQuestion(true))} + > + + + + + + + ); +}; \ No newline at end of file diff --git a/src/views/DBTableManager.tsx b/src/views/DBTableManager.tsx index dcfb6ee5..6fd17947 100644 --- a/src/views/DBTableManager.tsx +++ b/src/views/DBTableManager.tsx @@ -542,6 +542,7 @@ export const DBTableSelectionDialog: React.FC<{ buttonElement: any }> = function rowCount: dbTable.row_count, }, anchored: true, // by default, db tables are anchored + explorativeQuestions: [] } dispatch(dfActions.loadTable(table)); dispatch(fetchFieldSemanticType(table)); @@ -1396,21 +1397,21 @@ export const DataQueryForm: React.FC<{ }} /> - {queryResult?.status === "error" && + {queryResult?.status === "error" && {queryResult?.message} - } - + - } + diff --git a/src/views/DataFormulator.tsx b/src/views/DataFormulator.tsx index a76aea13..9a384782 100644 --- a/src/views/DataFormulator.tsx +++ b/src/views/DataFormulator.tsx @@ -8,6 +8,7 @@ import { useDispatch, useSelector } from "react-redux"; /* code change */ import { DataFormulatorState, dfActions, + dfSelectors, } from '../app/dfSlice' import _ from 'lodash'; @@ -43,6 +44,7 @@ import exampleImageTable from "../assets/example-image-table.png"; import { ModelSelectionButton } from './ModelSelectionDialog'; import { DBTableSelectionDialog } from './DBTableManager'; import { connectToSSE } from './SSEClient'; +import { getUrls } from '../app/utils'; //type AppProps = ConnectedProps; @@ -51,7 +53,16 @@ export const DataFormulatorFC = ({ }) => { const displayPanelSize = useSelector((state: DataFormulatorState) => state.displayPanelSize); const visPaneSize = useSelector((state: DataFormulatorState) => state.visPaneSize); const tables = useSelector((state: DataFormulatorState) => state.tables); - const selectedModelId = useSelector((state: DataFormulatorState) => state.selectedModelId); + + const models = useSelector((state: DataFormulatorState) => state.models); + const modelSlots = useSelector((state: DataFormulatorState) => state.modelSlots); + const testedModels = useSelector((state: DataFormulatorState) => state.testedModels); + + const noBrokenModelSlots= useSelector((state: DataFormulatorState) => { + const slotTypes = dfSelectors.getAllSlotTypes(); + return slotTypes.every( + slotType => state.modelSlots[slotType] !== undefined && state.testedModels.find(t => t.id == state.modelSlots[slotType])?.status != 'error'); + }); const dispatch = useDispatch(); @@ -59,6 +70,45 @@ export const DataFormulatorFC = ({ }) => { document.title = toolName; }, []); + useEffect(() => { + const findWorkingModel = async () => { + let assignedModels = models.filter(m => Object.values(modelSlots).includes(m.id)); + let unassignedModels = models.filter(m => !Object.values(modelSlots).includes(m.id)); + + // Combine both arrays: assigned models first, then unassigned models + let allModelsToTest = [...assignedModels, ...unassignedModels]; + + for (let i = 0; i < allModelsToTest.length; i++) { + let model = allModelsToTest[i]; + let isAssignedModel = i < assignedModels.length; + + const message = { + method: 'POST', + headers: { 'Content-Type': 'application/json', }, + body: JSON.stringify({ + model: model, + }), + }; + try { + const response = await fetch(getUrls().TEST_MODEL, {...message }); + const data = await response.json(); + const status = data["status"] || 'error'; + dispatch(dfActions.updateModelStatus({id: model.id, status, message: data["message"] || ""})); + // For unassigned models, break when we find a working one + if (!isAssignedModel && status == 'ok') { + break; + } + } catch (error) { + dispatch(dfActions.updateModelStatus({id: model.id, status: 'error', message: (error as Error).message || 'Failed to test model'})); + } + } + }; + + if (models.length > 0) { + findWorkingModel(); + } + }, []); + let conceptEncodingPanel = ( @@ -173,7 +223,7 @@ Totals (7 entries) 5 5 5 15 return ( - {selectedModelId == undefined ? modelSelectionDialogBox : (tables.length > 0 ? fixedSplitPane : dataUploadRequestBox)} + {!noBrokenModelSlots ? modelSelectionDialogBox : (tables.length > 0 ? fixedSplitPane : dataUploadRequestBox)} ); } \ No newline at end of file diff --git a/src/views/DataThread.tsx b/src/views/DataThread.tsx index ba22d298..fb0f8878 100644 --- a/src/views/DataThread.tsx +++ b/src/views/DataThread.tsx @@ -236,11 +236,17 @@ let SingleThreadView: FC<{ let triggers = getTriggers(leafTable, tables); let highlightedTableIds: string[] = [leafTable.id]; + + let threadOriginalTableId: string | undefined = leafTable.derive?.trigger.sourceTableIds[0]; + let triggerToFirstNewTable: Trigger | undefined = undefined; if (leafTable.derive) { // find the first table that belongs to this thread, it should not be an intermediate table that has appeared in previous threads let firstNewTableIndex = triggers.findIndex(tg => !usedIntermediateTableIds.includes(tg.tableId)); + let firstNewTableId = firstNewTableIndex != -1 ? triggers[firstNewTableIndex].tableId : leafTable.id; + + triggerToFirstNewTable = triggers.find(t => t.resultTableId == firstNewTableId); // when firstNewTableIndex is -1, it means the leaf table should be the first one to display at the top of the thread if (firstNewTableIndex == -1) { @@ -287,11 +293,6 @@ let SingleThreadView: FC<{ highlightedTableIds = focusedTableId && tableIdList.includes(focusedTableId) ? tableIdList : []; } - let originTableIdOfThread = tables.find(t => t.id == leafTable.id)?.derive?.trigger.sourceTableIds[0]; - if (originTableIdOfThread == tableIdList[0]) { - originTableIdOfThread = undefined; - } - let tableElementList = tableIdList.map((tableId, i) => { if (tableId == leafTable.id && leafTable.anchored && tableIdList.length > 1) { @@ -316,7 +317,7 @@ let SingleThreadView: FC<{ dispatch(dfActions.setFocusedTable(tableId)); // Find and set the first chart associated with this table - let firstRelatedChart = charts.find((c: Chart) => c.tableRef == tableId); + let firstRelatedChart = charts.find((c: Chart) => c.tableRef == tableId && c.source != "trigger"); if (firstRelatedChart) { dispatch(dfActions.setFocusedChart(firstRelatedChart.id)); @@ -375,11 +376,11 @@ let SingleThreadView: FC<{ onClick={() => { dispatch(dfActions.setFocusedTable(tableId)); if (focusedChart?.tableRef != tableId) { - let firstRelatedChart = charts.find((c: Chart) => c.tableRef == tableId); + let firstRelatedChart = charts.find((c: Chart) => c.tableRef == tableId && c.source != 'trigger'); if (firstRelatedChart) { dispatch(dfActions.setFocusedChart(firstRelatedChart.id)); } else { - //dispatch(dfActions.createNewChart({ tableId: tableId })); + //dispatch(dfActions.createNewChart({ tableId: tableId, chartType: '?' })); } } }}> @@ -463,7 +464,7 @@ let SingleThreadView: FC<{ { event.stopPropagation(); - dispatch(dfActions.createNewChart({ tableId: tableId })); + dispatch(dfActions.createNewChart({ tableId: tableId, chartType: '?' })); }} /> @@ -535,6 +536,8 @@ let SingleThreadView: FC<{ content = w(tableElementList, triggerCards, "") + let selectedClassName = focusedChartId == triggerToFirstNewTable?.chart?.id ? 'selected-card' : ''; + return
- {originTableIdOfThread && - - {`${tables.find(t => t.id === originTableIdOfThread)?.displayId || originTableIdOfThread}`} - - - } + {threadOriginalTableId && !tableIdList.includes(threadOriginalTableId) && + + + {`${tables.find(t => t.id === threadOriginalTableId)?.displayId || threadOriginalTableId}`} + + + {triggerToFirstNewTable && threadOriginalTableId != triggerToFirstNewTable.tableId && + + {`${tables.find(t => t.id === triggerToFirstNewTable.tableId)?.displayId || triggerToFirstNewTable.tableId}`} + + + } + {triggerToFirstNewTable && +
+ + + +
} + +
+ } {content}
@@ -630,10 +656,29 @@ const MemoizedChartObject = memo<{ }>(({ chart, table, conceptShelfItems, status, onChartClick, onDelete }) => { let visTableRows = structuredClone(table.rows); + let deleteButton = + + { + event.stopPropagation(); + onDelete(chart.id); + }}> + + - if (chart.chartType == "Auto") { - let element = - + if (['Auto', '?'].includes(chart.chartType)) { + let element = onChartClick(chart.id, table.id)} + sx={{ width: "100%", color: 'text.secondary', height: 48, display: "flex", backgroundColor: "white", position: 'relative', flexDirection: "column" }}> + {status == 'pending' ? + + : ''} + + {deleteButton} return element; } @@ -658,15 +703,7 @@ const MemoizedChartObject = memo<{ {generateChartSkeleton(chartTemplate?.icon, 48, 48, chart.chartType == 'Table' ? 1 : 0.5)} - - - { - event.stopPropagation(); - onDelete(chart.id); - }}> - - + {deleteButton}
; return element; @@ -810,6 +847,7 @@ export const DataThread: FC<{}> = function ({ }) { let view = = function ({ }) { let jumpButtons = drawerOpen ? jumpButtonsDrawerOpen : jumpButtonDrawerClosed; let carousel = ( - + = function ({ }) { {view} diff --git a/src/views/EncodingShelfCard.tsx b/src/views/EncodingShelfCard.tsx index 020d9cf4..2948d7bc 100644 --- a/src/views/EncodingShelfCard.tsx +++ b/src/views/EncodingShelfCard.tsx @@ -323,7 +323,7 @@ export const EncodingShelfCard: FC = function ({ chartId let actionTables = actionTableIds.map(id => tables.find(t => t.id == id) as DictTable); - let instruction = (chart.chartType == 'Auto' && prompt == "") ? "let's get started" : prompt; + let instruction = (['Auto'].includes(chart.chartType) && prompt == "") ? "let's get started" : prompt; if (currentTable.derive == undefined && instruction == "" && (activeFields.length > 0 && activeCustomFields.length == 0) && @@ -646,7 +646,7 @@ export const EncodingShelfCard: FC = function ({ chartId }} value={prompt} label="" - placeholder={chart.chartType == "Auto" ? "what do you want to visualize?" : "formulate data"} + placeholder={['Auto'].includes(chart.chartType) ? "what do you want to visualize?" : "formulate data"} fullWidth multiline variant="standard" @@ -728,7 +728,6 @@ export const EncodingShelfCard: FC = function ({ chartId })} - {encodingBoxGroups} @@ -738,10 +737,10 @@ export const EncodingShelfCard: FC = function ({ chartId const encodingShelfCard = ( + sx={{ padding: 1, maxWidth: "400px", display: 'flex', flexDirection: 'row', alignItems: "center", backgroundColor: trigger ? "rgba(255, 160, 122, 0.07)" : "" }}> {channelComponent} ) - return encodingShelfCard; + return encodingShelfCard ; } \ No newline at end of file diff --git a/src/views/ModelSelectionDialog.tsx b/src/views/ModelSelectionDialog.tsx index 3e286f92..844b4433 100644 --- a/src/views/ModelSelectionDialog.tsx +++ b/src/views/ModelSelectionDialog.tsx @@ -102,39 +102,7 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { }); }, []); - useEffect(() => { - const findWorkingModel = async () => { - for (let i = 0; i < models.length; i++) { - if (testedModels.find(t => t.id == models[i].id)) { - continue; - } - const model = models[i]; - const message = { - method: 'POST', - headers: { 'Content-Type': 'application/json', }, - body: JSON.stringify({ - model: model, - }), - }; - try { - const response = await fetch(getUrls().TEST_MODEL, {...message }); - const data = await response.json(); - const status = data["status"] || 'error'; - updateModelStatus(model, status, data["message"] || ""); - if (status === 'ok') { - break; - } - } catch (error) { - updateModelStatus(model, 'error', (error as Error).message || 'Failed to test model'); - } - } - }; - - if (models.length > 0 && testedModels.filter(t => t.status == 'ok').length == 0) { - findWorkingModel(); - } - }, []); - + let updateModelStatus = (model: ModelConfig, status: 'ok' | 'error' | 'testing' | 'unknown', message: string) => { dispatch(dfActions.updateModelStatus({id: model.id, status, message})); } @@ -250,7 +218,8 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { p: 1.5, border: '1px solid #e0e0e0', borderRadius: 1, - borderColor: assignedModel ? theme.palette.success.main : theme.palette.error.main + borderColor: assignedModel && getStatus(assignedModelId) == 'ok' ? theme.palette.success.main : + getStatus(assignedModelId) == 'error' ? theme.palette.error.main : theme.palette.warning.main }} > @@ -290,7 +259,9 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { )} - + {getStatus(assignedModelId) === 'ok' ? + : getStatus(assignedModelId) === 'error' ? + : } : 'Unknown model'; }} > @@ -741,12 +712,14 @@ export const ModelSelectionButton: React.FC<{}> = ({ }) => { + let notAllSlotsReady = Object.values(tempModelSlots).filter(id => id).length !== dfSelectors.getAllSlotTypes().length + || Object.values(tempModelSlots).filter(id => id).some(id => getStatus(id) !== 'ok'); + return <> = ({ }) => { {showKeys ? 'hide' : 'show'} keys )} - + let chartSelectionBox = + {Object.entries(CHART_TEMPLATES).map(([cls, templates])=>templates).flat().filter(t => t.chart != "Auto").map(t => + { + return + } )} return ( + {focusedTableId ? : null} + + + or, select a chart type + + {chartSelectionBox} ) @@ -881,7 +888,6 @@ export const VisualizationViewFC: FC = function VisualizationView let chartEditor = - let finalView = ; if (visViewMode == "gallery") {