Skip to content

Commit 3b0df51

Browse files
committed
halfway through working out new concept derivation function
1 parent d554110 commit 3b0df51

14 files changed

+457
-500
lines changed

py-src/data_formulator/agents/agent_sql_data_rec.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from data_formulator.agents.agent_utils import extract_json_objects, extract_code_from_gpt_response
77
from data_formulator.agents.agent_sql_data_transform import get_sql_table_statistics_str
88

9+
import random
10+
import string
11+
912
import traceback
1013

1114

@@ -153,13 +156,32 @@ def process_gpt_response(self, input_tables, messages, response):
153156
code_str = code_blocks[-1]
154157

155158
try:
156-
query_output = self.conn.execute(code_str).fetch_df()
159+
random_suffix = ''.join(random.choices(string.ascii_lowercase, k=4))
160+
table_name = f"view_{random_suffix}"
161+
162+
create_query = f"CREATE VIEW IF NOT EXISTS {table_name} AS {code_str}"
163+
self.conn.execute(create_query)
164+
self.conn.commit()
165+
166+
# Check how many rows are in the table
167+
row_count = self.conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
168+
169+
# Only limit to 5000 if there are more rows
170+
if row_count > 5000:
171+
query_output = self.conn.execute(f"SELECT * FROM {table_name} LIMIT 5000").fetch_df()
172+
else:
173+
query_output = self.conn.execute(f"SELECT * FROM {table_name}").fetch_df()
174+
self.conn.execute(f"DROP VIEW {table_name}")
157175

158176
result = {
159177
"status": "ok",
160178
"code": code_str,
161179
"content": {
162180
'rows': query_output.to_dict('records'),
181+
'virtual': {
182+
'table_name': table_name,
183+
'row_count': row_count
184+
} if row_count > 5000 else None
163185
},
164186
}
165187
except Exception as e:

py-src/data_formulator/agents/agent_sql_data_transform.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Licensed under the MIT License.
33

44
import json
5-
import time
65
import random
76
import string
87

@@ -201,9 +200,8 @@ def process_gpt_sql_response(self, response, messages):
201200

202201
try:
203202
# Generate unique table name directly with timestamp and random suffix
204-
timestamp = int(time.time())
205-
random_suffix = ''.join(random.choices(string.ascii_lowercase, k=5))
206-
table_name = f"result_{timestamp}_{random_suffix}"
203+
random_suffix = ''.join(random.choices(string.ascii_lowercase, k=4))
204+
table_name = f"view_{random_suffix}"
207205

208206
create_query = f"CREATE VIEW IF NOT EXISTS {table_name} AS {query_str}"
209207
self.conn.execute(create_query)

py-src/data_formulator/app.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from vega_datasets import data as vega_data
2828

2929
from data_formulator.agents.agent_concept_derive import ConceptDeriveAgent
30+
from data_formulator.agents.agent_py_concept_derive import PyConceptDeriveAgent
31+
3032
from data_formulator.agents.agent_py_data_transform import PythonDataTransformationAgent
3133
from data_formulator.agents.agent_sql_data_transform import SQLDataTransformationAgent
3234
from data_formulator.agents.agent_py_data_rec import PythonDataRecAgent
@@ -364,6 +366,31 @@ def derive_concept_request():
364366
return response
365367

366368

369+
@app.route('/api/derive-py-concept', methods=['GET', 'POST'])
370+
def derive_py_concept():
371+
372+
if request.is_json:
373+
app.logger.info("# code query: ")
374+
content = request.get_json()
375+
token = content["token"]
376+
377+
client = get_client(content['model'])
378+
379+
app.logger.info(f" model: {content['model']}")
380+
agent = PyConceptDeriveAgent(client=client)
381+
382+
#print(content["input_data"])
383+
384+
results = agent.run(content["input_data"], [f['name'] for f in content["input_fields"]],
385+
content["output_name"], content["description"])
386+
387+
response = flask.jsonify({ "status": "ok", "token": token, "results": results })
388+
else:
389+
response = flask.jsonify({ "token": -1, "status": "error", "results": [] })
390+
391+
response.headers.add('Access-Control-Allow-Origin', '*')
392+
return response
393+
367394
@app.route('/api/clean-data', methods=['GET', 'POST'])
368395
def clean_data_request():
369396

@@ -612,22 +639,20 @@ def list_tables():
612639
sample_rows = db.execute(f"SELECT * FROM {table_name} LIMIT 1000").fetchall()
613640

614641
# Check if this is a view or a table
615-
is_view = False
616642
try:
617-
# In most SQL databases, views are listed in a system table
618-
# For DuckDB, we can check if it's a view by querying the system tables
619-
view_check = db.execute(f"SELECT * FROM duckdb_views() WHERE view_name = '{table_name}'").fetchone()
620-
is_view = view_check is not None
621-
except Exception:
643+
# Get both view existence and source in one query
644+
view_info = db.execute(f"SELECT view_name, sql FROM duckdb_views() WHERE view_name = '{table_name}'").fetchone()
645+
view_source = view_info[1] if view_info else None
646+
except Exception as e:
622647
# If the query fails, assume it's a regular table
623-
pass
648+
view_source = None
624649

625650
result.append({
626651
"name": table_name,
627652
"columns": [{"name": col[0], "type": col[1]} for col in columns],
628653
"row_count": row_count,
629654
"sample_rows": [dict(zip([col[0] for col in columns], row)) for row in sample_rows],
630-
"is_view": is_view,
655+
"view_source": view_source
631656
})
632657

633658
return jsonify({
@@ -652,6 +677,8 @@ def sample_table():
652677
projection_fields = data.get('projection_fields', []) # if empty, we want to include all fields
653678
method = data.get('method', 'random') # one of 'random', 'head', 'bottom'
654679
order_by_fields = data.get('order_by_fields', [])
680+
681+
print(f"sample_table: {table_id}, {sample_size}, {projection_fields}, {method}, {order_by_fields}")
655682

656683
# Validate field names against table columns to prevent SQL injection
657684
with db_manager.connection(session['session_id']) as db:
@@ -817,7 +844,7 @@ def create_table():
817844

818845
@app.route('/api/tables/delete-table', methods=['POST'])
819846
def drop_table():
820-
"""Drop a table"""
847+
"""Drop a table or view"""
821848
try:
822849
data = request.get_json()
823850
table_name = data.get('table_name')
@@ -826,11 +853,14 @@ def drop_table():
826853
return jsonify({"status": "error", "message": "No table name provided"}), 400
827854

828855
with db_manager.connection(session['session_id']) as db:
856+
# First try to drop it as a view
857+
db.execute(f"DROP VIEW IF EXISTS {table_name}")
858+
# Then try to drop it as a table
829859
db.execute(f"DROP TABLE IF EXISTS {table_name}")
830860

831861
return jsonify({
832862
"status": "success",
833-
"message": f"Table {table_name} dropped"
863+
"message": f"Table/view {table_name} dropped"
834864
})
835865

836866
except Exception as e:

src/app/dfSlice.tsx

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ export interface DataFormulatorState {
4343
testedModels: {id: string, status: 'ok' | 'error' | 'testing' | 'unknown', message: string}[];
4444

4545
tables : DictTable[];
46+
extTables: {
47+
baseTableRef: string,
48+
rows: any[],
49+
virtualTableRef?: string, // whether this is a virtual table
50+
}[]; // extensions to base tables with derived tables
4651
charts: Chart[];
4752

4853
activeChallenges: {tableId: string, challenges: { text: string; difficulty: 'easy' | 'medium' | 'hard'; }[]}[];
@@ -79,6 +84,7 @@ const initialState: DataFormulatorState = {
7984
testedModels: [],
8085

8186
tables: [],
87+
extTables: [],
8288
charts: [],
8389

8490
activeChallenges: [],
@@ -261,6 +267,7 @@ export const dataFormulatorSlice = createSlice({
261267
state.testedModels = [];
262268

263269
state.tables = [];
270+
state.extTables = [];
264271
state.charts = [];
265272
state.activeChallenges = [];
266273

@@ -291,6 +298,7 @@ export const dataFormulatorSlice = createSlice({
291298

292299
//state.table = undefined;
293300
state.tables = savedState.tables || [];
301+
state.extTables = savedState.extTables || [];
294302
state.charts = savedState.charts || [];
295303

296304
state.activeChallenges = savedState.activeChallenges || [];
@@ -370,6 +378,37 @@ export const dataFormulatorSlice = createSlice({
370378
addChallenges: (state, action: PayloadAction<{tableId: string, challenges: { text: string; difficulty: 'easy' | 'medium' | 'hard'; }[]}>) => {
371379
state.activeChallenges = [...state.activeChallenges, action.payload];
372380
},
381+
setExtTables: (state, action: PayloadAction<{baseTableRef: string, rows: any[], virtualTableRef?: string}>) => {
382+
let extTable = action.payload;
383+
384+
let baseTable = state.tables.find(t => t.id == extTable.baseTableRef) as DictTable;
385+
let existingExtTable = state.extTables.find(t => t.baseTableRef == extTable.baseTableRef);
386+
387+
// we want to extend base rows with the new rows
388+
let baseRows = existingExtTable ? existingExtTable.rows : baseTable.rows;
389+
390+
if (existingExtTable) {
391+
for (let i = 0; i < extTable.rows.length; i++) {
392+
extTable.rows[i] = {...extTable.rows[i], ...baseRows[i]};
393+
}
394+
}
395+
396+
if (state.extTables.some(t => t.baseTableRef == extTable.baseTableRef)) {
397+
state.extTables = state.extTables.map(t => t.baseTableRef == extTable.baseTableRef ? {...t, rows: extTable.rows} : t);
398+
} else {
399+
state.extTables = [...state.extTables, extTable];
400+
}
401+
},
402+
removeExtTable: (state, action: PayloadAction<string>) => {
403+
let baseTableRef = action.payload;
404+
state.extTables = state.extTables.filter(t => t.baseTableRef != baseTableRef);
405+
},
406+
updateExtTableFieldName: (state, action: PayloadAction<{baseTableRef: string, oldName: string, newName: string}>) => {
407+
let baseTableRef = action.payload.baseTableRef;
408+
let oldName = action.payload.oldName;
409+
let newName = action.payload.newName;
410+
state.extTables = state.extTables.map(t => t.baseTableRef == baseTableRef ? {...t, rows: t.rows.map(r => r[oldName] = newName)} : t);
411+
},
373412
createNewChart: (state, action: PayloadAction<{chartType?: string, tableId?: string}>) => {
374413
let chartType = action.payload.chartType;
375414
let tableId = action.payload.tableId || state.tables[0].id;
@@ -542,7 +581,14 @@ export const dataFormulatorSlice = createSlice({
542581
&& Object.entries(chart.encodingMap).some(([channel, encoding]) => encoding.fieldID && conceptID == encoding.fieldID))) {
543582
console.log("cannot delete!")
544583
} else {
545-
state.conceptShelfItems = state.conceptShelfItems.filter(field => field.id != conceptID);
584+
585+
let field = state.conceptShelfItems.find(f => f.id == conceptID);
586+
if (field?.source == "derived") {
587+
// delete generated column from the derived table
588+
state.extTables = state.extTables.map(t => t.baseTableRef == field.tableRef ? {...t, rows: t.rows.map(r => delete r[field.name])} : t);
589+
}
590+
state.conceptShelfItems = state.conceptShelfItems.filter(f => f.id != conceptID);
591+
546592
for (let chart of state.charts) {
547593
for (let [channel, encoding] of Object.entries(chart.encodingMap)) {
548594
if (encoding.fieldID && conceptID == encoding.fieldID) {

src/app/utils.tsx

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ export function getUrls() {
3434

3535
// these functions involves openai models
3636
DERIVE_CONCEPT_URL: `/api/derive-concept-request`,
37+
DERIVE_PY_CONCEPT: `/api/derive-py-concept`,
38+
3739
SORT_DATA_URL: `/api/codex-sort-request`,
3840
CLEAN_DATA_URL: `/api/clean-data`,
3941

@@ -127,7 +129,45 @@ export function runCodeOnInputListsInVM(
127129
return ioPairs;
128130
}
129131

130-
export function baseTableToExtTable(table: any[], derivedFields: FieldItem[], allFields: FieldItem[]) {
132+
export function prepVisTable(table: any[], derivedFields: FieldItem[], allFields: FieldItem[], encodingMap: EncodingMap) {
133+
134+
let binningFields = []
135+
let aggregateFields = []
136+
let groupByFields = []
137+
138+
for (const [channel, encoding] of Object.entries(encodingMap)) {
139+
const field = encoding.fieldID ? _.find(allFields, (f) => f.id === encoding.fieldID) : undefined;
140+
if (field) {
141+
if (encoding.bin) {
142+
binningFields.push(field.name);
143+
}
144+
}
145+
if (encoding.aggregate) {
146+
aggregateFields.push([field?.name, encoding.aggregate]);
147+
} else {
148+
if (field) {
149+
groupByFields.push(field.name);
150+
}
151+
}
152+
}
153+
154+
///// TODOTODO:
155+
156+
return [];
157+
}
158+
159+
160+
161+
export function baseTableToExtTable(table: DictTable, extTable?: {rows: any[], baseTableRef: string, virtualTableRef?: string}) {
162+
// derive fields from derivedFields from the original table
163+
if (extTable) {
164+
return structuredClone(extTable.rows);
165+
} else {
166+
return structuredClone(table.rows);
167+
}
168+
}
169+
170+
export function baseTableToExtTableOld(table: any[], derivedFields: FieldItem[], allFields: FieldItem[]) {
131171
// derive fields from derivedFields from the original table
132172

133173
if (table.length == 0) {

0 commit comments

Comments
 (0)