Skip to content

Commit f75b957

Browse files
committed
almost there!
1 parent f7a1b70 commit f75b957

File tree

10 files changed

+318
-236
lines changed

10 files changed

+318
-236
lines changed

py-src/data_formulator/app.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,60 @@ def list_tables():
666666
"message": str(e)
667667
}), 500
668668

669+
def assemble_query(aggregate_fields_and_functions, group_fields, columns, table_name):
670+
"""
671+
Assembles a SELECT query string based on binning, aggregation, and grouping specifications.
672+
673+
Args:
674+
bin_fields (list): Fields to be binned into ranges
675+
aggregate_fields_and_functions (list): List of tuples (field, function) for aggregation
676+
group_fields (list): Fields to group by
677+
columns (list): All available column names
678+
679+
Returns:
680+
str: The assembled SELECT query projection part
681+
"""
682+
select_parts = []
683+
output_column_names = []
684+
685+
# Handle aggregate fields and functions
686+
for field, function in aggregate_fields_and_functions:
687+
if field is None:
688+
# Handle count(*) case
689+
if function.lower() == 'count':
690+
select_parts.append('COUNT(*) as _count')
691+
output_column_names.append('_count')
692+
elif field in columns:
693+
if function.lower() == 'count':
694+
alias = f'_count'
695+
select_parts.append(f'COUNT(*) as {alias}')
696+
output_column_names.append(alias)
697+
else:
698+
# Sanitize function name and create alias
699+
if function in ["avg", "average", "mean"]:
700+
aggregate_function = "AVG"
701+
else:
702+
aggregate_function = function.upper()
703+
704+
alias = f'{field}_{function}'
705+
select_parts.append(f'{aggregate_function}("{field}") as {alias}')
706+
output_column_names.append(alias)
707+
708+
# Handle group fields
709+
for field in group_fields:
710+
if field in columns:
711+
select_parts.append(f'"{field}"')
712+
output_column_names.append(field)
713+
# If no fields are specified, select all columns
714+
if not select_parts:
715+
select_parts = ["*"]
716+
output_column_names = columns
717+
718+
from_clause = f"FROM {table_name}"
719+
group_by_clause = f"GROUP BY {', '.join(group_fields)}" if len(group_fields) > 0 and len(aggregate_fields_and_functions) > 0 else ""
720+
721+
query = f"SELECT {', '.join(select_parts)} {from_clause} {group_by_clause}"
722+
return query, output_column_names
669723

670724
@app.route('/api/tables/sample-table', methods=['POST'])
671725
def sample_table():
@@ -674,52 +728,58 @@ def sample_table():
674728
data = request.get_json()
675729
table_id = data.get('table')
676730
sample_size = data.get('size', 1000)
677-
projection_fields = data.get('projection_fields', []) # if empty, we want to include all fields
731+
aggregate_fields_and_functions = data.get('aggregate_fields_and_functions', []) # each element is a tuple (field, function)
732+
select_fields = data.get('select_fields', []) # if empty, we want to include all fields
678733
method = data.get('method', 'random') # one of 'random', 'head', 'bottom'
679734
order_by_fields = data.get('order_by_fields', [])
680735

681-
print(f"sample_table: {table_id}, {sample_size}, {projection_fields}, {method}, {order_by_fields}")
736+
print(f"sample_table: {table_id}, {sample_size}, {aggregate_fields_and_functions}, {select_fields}, {method}, {order_by_fields}")
682737

738+
total_row_count = 0
683739
# Validate field names against table columns to prevent SQL injection
684740
with db_manager.connection(session['session_id']) as db:
685741
# Get valid column names
686742
columns = [col[0] for col in db.execute(f"DESCRIBE {table_id}").fetchall()]
687743

688744
# Filter order_by_fields to only include valid column names
689745
valid_order_by_fields = [field for field in order_by_fields if field in columns]
690-
valid_projection_fields = [field for field in projection_fields if field in columns]
746+
valid_aggregate_fields_and_functions = [
747+
field_and_function for field_and_function in aggregate_fields_and_functions
748+
if field_and_function[0] is None or field_and_function[0] in columns
749+
]
750+
valid_select_fields = [field for field in select_fields if field in columns]
691751

692-
if len(valid_projection_fields) == 0:
693-
projection_fields_str = "*"
694-
else:
695-
projection_fields_str = ", ".join(valid_projection_fields)
752+
query, output_column_names = assemble_query(valid_aggregate_fields_and_functions, valid_select_fields, columns, table_id)
696753

754+
# Modify the original query to include the count:
755+
count_query = f"SELECT *, COUNT(*) OVER () as total_count FROM ({query}) as subq LIMIT 1"
756+
result = db.execute(count_query).fetchone()
757+
total_row_count = result[-1] if result else 0
758+
759+
# Add ordering and limit to the main query
697760
if method == 'random':
698-
result = db.execute(f"SELECT {projection_fields_str} FROM {table_id} ORDER BY RANDOM() LIMIT {sample_size}").fetchall()
761+
query += f" ORDER BY RANDOM() LIMIT {sample_size}"
699762
elif method == 'head':
700763
if valid_order_by_fields:
701764
# Build ORDER BY clause with validated fields
702765
order_by_clause = ", ".join([f'"{field}"' for field in valid_order_by_fields])
703-
result = db.execute(f"SELECT {projection_fields_str} FROM {table_id} ORDER BY {order_by_clause} LIMIT {sample_size}").fetchall()
766+
query += f" ORDER BY {order_by_clause} LIMIT {sample_size}"
704767
else:
705-
result = db.execute(f"SELECT {projection_fields_str} FROM {table_id} LIMIT {sample_size}").fetchall()
768+
query += f" LIMIT {sample_size}"
706769
elif method == 'bottom':
707770
if valid_order_by_fields:
708771
# Build ORDER BY clause with validated fields in descending order
709772
order_by_clause = ", ".join([f'"{field}" DESC' for field in valid_order_by_fields])
710-
result = db.execute(f"SELECT {projection_fields_str} FROM {table_id} ORDER BY {order_by_clause} LIMIT {sample_size}").fetchall()
773+
query += f" ORDER BY {order_by_clause} LIMIT {sample_size}"
711774
else:
712-
result = db.execute(f"SELECT {projection_fields_str} FROM {table_id} ORDER BY ROWID DESC LIMIT {sample_size}").fetchall()
775+
query += f" ORDER BY ROWID DESC LIMIT {sample_size}"
713776

714-
# When using projection_fields, we need to use those as our column names
715-
if len(valid_projection_fields) > 0:
716-
column_names = valid_projection_fields
717-
else:
718-
column_names = columns
777+
result = db.execute(query).fetchall()
719778

720779
return jsonify({
721780
"status": "success",
722-
"rows": [dict(zip(column_names, row)) for row in result]
781+
"rows": [dict(zip(output_column_names, row)) for row in result],
782+
"total_row_count": total_row_count
723783
})
724784
except Exception as e:
725785
print(e)

src/app/App.tsx

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ import MuiAppBar from '@mui/material/AppBar';
4444
import { createTheme, styled, ThemeProvider } from '@mui/material/styles';
4545

4646
import PowerSettingsNewIcon from '@mui/icons-material/PowerSettingsNew';
47+
import ClearIcon from '@mui/icons-material/Clear';
48+
4749
import { DataFormulatorFC } from '../views/DataFormulator';
4850

4951
import GridViewIcon from '@mui/icons-material/GridView';
@@ -317,19 +319,25 @@ const ConfigDialog: React.FC = () => {
317319
const dispatch = useDispatch();
318320
const config = useSelector((state: DataFormulatorState) => state.config);
319321

322+
320323
const [formulateTimeoutSeconds, setFormulateTimeoutSeconds] = useState(config.formulateTimeoutSeconds);
321324
const [maxRepairAttempts, setMaxRepairAttempts] = useState(config.maxRepairAttempts);
322325

326+
const [defaultChartWidth, setDefaultChartWidth] = useState(config.defaultChartWidth);
327+
const [defaultChartHeight, setDefaultChartHeight] = useState(config.defaultChartHeight);
328+
323329
// Add check for changes
324330
const hasChanges = formulateTimeoutSeconds !== config.formulateTimeoutSeconds ||
325-
maxRepairAttempts !== config.maxRepairAttempts;
331+
maxRepairAttempts !== config.maxRepairAttempts ||
332+
defaultChartWidth !== config.defaultChartWidth ||
333+
defaultChartHeight !== config.defaultChartHeight;
326334

327335
return (
328336
<>
329337
<Button variant="text" sx={{textTransform: 'none'}} onClick={() => setOpen(true)} startIcon={<SettingsIcon />}>
330338
<Box component="span" sx={{lineHeight: 1.2, display: 'flex', flexDirection: 'column', alignItems: 'left'}}>
331-
<Box component="span" sx={{py: 0, my: 0, fontSize: '10px', mr: 'auto'}}>timeout={config.formulateTimeoutSeconds}s</Box>
332-
<Box component="span" sx={{py: 0, my: 0, fontSize: '10px', mr: 'auto'}}>max_repair={config.maxRepairAttempts}</Box>
339+
<Box component="span" sx={{py: 0, my: 0, fontSize: '10px', mr: 'auto'}}>default_timeout={config.formulateTimeoutSeconds}s</Box>
340+
<Box component="span" sx={{py: 0, my: 0, fontSize: '10px', mr: 'auto'}}>chart_size={config.defaultChartWidth}x{config.defaultChartHeight}</Box>
333341
</Box>
334342
</Button>
335343
<Dialog onClose={() => setOpen(false)} open={open}>
@@ -340,9 +348,55 @@ const ConfigDialog: React.FC = () => {
340348
display: 'flex',
341349
flexDirection: 'column',
342350
gap: 3,
343-
my: 2,
344351
maxWidth: 400
345352
}}>
353+
<Divider><Typography variant="caption">Frontend configuration</Typography></Divider>
354+
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2 }}>
355+
<Box sx={{ flex: 1 }}>
356+
<TextField
357+
label="default chart width"
358+
type="number"
359+
variant="outlined"
360+
value={defaultChartWidth}
361+
onChange={(e) => {
362+
const value = parseInt(e.target.value);
363+
setDefaultChartWidth(value);
364+
}}
365+
fullWidth
366+
inputProps={{
367+
min: 100,
368+
max: 1000
369+
}}
370+
error={defaultChartWidth < 100 || defaultChartWidth > 1000}
371+
helperText={defaultChartWidth < 100 || defaultChartWidth > 1000 ?
372+
"Value must be between 100 and 1000 pixels" : ""}
373+
/>
374+
</Box>
375+
<Typography variant="caption" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
376+
<ClearIcon fontSize="small" />
377+
</Typography>
378+
<Box sx={{ flex: 1 }}>
379+
<TextField
380+
label="default chart height"
381+
type="number"
382+
variant="outlined"
383+
value={defaultChartHeight}
384+
onChange={(e) => {
385+
const value = parseInt(e.target.value);
386+
setDefaultChartHeight(value);
387+
}}
388+
fullWidth
389+
inputProps={{
390+
min: 100,
391+
max: 1000
392+
}}
393+
error={defaultChartHeight < 100 || defaultChartHeight > 1000}
394+
helperText={defaultChartHeight < 100 || defaultChartHeight > 1000 ?
395+
"Value must be between 100 and 1000 pixels" : ""}
396+
/>
397+
</Box>
398+
</Box>
399+
<Divider><Typography variant="caption">Backend configuration</Typography></Divider>
346400
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2 }}>
347401
<Box sx={{ flex: 1 }}>
348402
<TextField
@@ -366,9 +420,6 @@ const ConfigDialog: React.FC = () => {
366420
<Typography variant="caption" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
367421
Maximum time allowed for the formulation process before timing out.
368422
</Typography>
369-
<Typography variant="caption" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
370-
Smaller values (&lt;30s) make the model fails fast thus providing a smoother UI experience. Increase this value for slow models.
371-
</Typography>
372423
</Box>
373424
</Box>
374425
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2 }}>
@@ -406,13 +457,18 @@ const ConfigDialog: React.FC = () => {
406457
<Button sx={{marginRight: 'auto'}} onClick={() => {
407458
setFormulateTimeoutSeconds(30);
408459
setMaxRepairAttempts(1);
460+
setDefaultChartWidth(300);
461+
setDefaultChartHeight(300);
409462
}}>Reset to default</Button>
410463
<Button onClick={() => setOpen(false)}>Cancel</Button>
411464
<Button
412465
variant={hasChanges ? "contained" : "text"}
413-
disabled={!hasChanges || isNaN(maxRepairAttempts) || maxRepairAttempts <= 0 || maxRepairAttempts > 5 || isNaN(formulateTimeoutSeconds) || formulateTimeoutSeconds <= 0 || formulateTimeoutSeconds > 3600}
466+
disabled={!hasChanges || isNaN(maxRepairAttempts) || maxRepairAttempts <= 0 || maxRepairAttempts > 5
467+
|| isNaN(formulateTimeoutSeconds) || formulateTimeoutSeconds <= 0 || formulateTimeoutSeconds > 3600
468+
|| isNaN(defaultChartWidth) || defaultChartWidth <= 0 || defaultChartWidth > 1000
469+
|| isNaN(defaultChartHeight) || defaultChartHeight <= 0 || defaultChartHeight > 1000}
414470
onClick={() => {
415-
dispatch(dfActions.setConfig({formulateTimeoutSeconds, maxRepairAttempts}));
471+
dispatch(dfActions.setConfig({formulateTimeoutSeconds, maxRepairAttempts, defaultChartWidth, defaultChartHeight}));
416472
setOpen(false);
417473
}}
418474
>

src/app/dfSlice.tsx

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ export interface DataFormulatorState {
6868
config: {
6969
formulateTimeoutSeconds: number;
7070
maxRepairAttempts: number;
71-
}
71+
defaultChartWidth: number;
72+
defaultChartHeight: number;
73+
}
7274
}
7375

7476
// Define the initial state using that type
@@ -104,6 +106,8 @@ const initialState: DataFormulatorState = {
104106
config: {
105107
formulateTimeoutSeconds: 30,
106108
maxRepairAttempts: 1,
109+
defaultChartWidth: 300,
110+
defaultChartHeight: 300,
107111
}
108112
}
109113

@@ -273,11 +277,7 @@ export const dataFormulatorSlice = createSlice({
273277

274278
state.chartSynthesisInProgress = [];
275279

276-
// avoid resetting config
277-
// state.config = {
278-
// formulateTimeoutSeconds: 30,
279-
// repairAttempts: 1,
280-
// }
280+
state.config = initialState.config;
281281
},
282282
loadState: (state, action: PayloadAction<any>) => {
283283

@@ -306,7 +306,9 @@ export const dataFormulatorSlice = createSlice({
306306

307307
state.config = savedState.config;
308308
},
309-
setConfig: (state, action: PayloadAction<{formulateTimeoutSeconds: number, maxRepairAttempts: number}>) => {
309+
setConfig: (state, action: PayloadAction<{
310+
formulateTimeoutSeconds: number, maxRepairAttempts: number,
311+
defaultChartWidth: number, defaultChartHeight: number}>) => {
310312
state.config = action.payload;
311313
},
312314
selectModel: (state, action: PayloadAction<string | undefined>) => {
@@ -548,8 +550,6 @@ export const dataFormulatorSlice = createSlice({
548550
if (field?.levels) {
549551
encoding.sortBy = JSON.stringify(field.levels);
550552
}
551-
} else if (prop == 'bin') {
552-
encoding.bin = value;
553553
} else if (prop == 'aggregate') {
554554
encoding.aggregate = value;
555555
} else if (prop == 'stack') {
@@ -573,8 +573,8 @@ export const dataFormulatorSlice = createSlice({
573573
let enc1 = chart.encodingMap[channel1];
574574
let enc2 = chart.encodingMap[channel2];
575575

576-
chart.encodingMap[channel1] = { fieldID: enc2.fieldID, aggregate: enc2.aggregate, bin: enc2.bin, sortBy: enc2.sortBy };
577-
chart.encodingMap[channel2] = { fieldID: enc1.fieldID, aggregate: enc1.aggregate, bin: enc1.bin, sortBy: enc1.sortBy };
576+
chart.encodingMap[channel1] = { fieldID: enc2.fieldID, aggregate: enc2.aggregate, sortBy: enc2.sortBy };
577+
chart.encodingMap[channel2] = { fieldID: enc1.fieldID, aggregate: enc1.aggregate, sortBy: enc1.sortBy };
578578
}
579579
},
580580
addConceptItems: (state, action: PayloadAction<FieldItem[]>) => {
@@ -621,7 +621,7 @@ export const dataFormulatorSlice = createSlice({
621621
for (let [channel, encoding] of Object.entries(chart.encodingMap)) {
622622
if (encoding.fieldID && conceptID == encoding.fieldID) {
623623
// clear the encoding
624-
chart.encodingMap[channel as Channel] = { bin: false }
624+
chart.encodingMap[channel as Channel] = { }
625625
}
626626
}
627627
}
@@ -639,7 +639,7 @@ export const dataFormulatorSlice = createSlice({
639639
for (let [channel, encoding] of Object.entries(chart.encodingMap)) {
640640
if (encoding.fieldID && conceptID == encoding.fieldID) {
641641
// clear the encoding
642-
chart.encodingMap[channel as Channel] = { bin: false }
642+
chart.encodingMap[channel as Channel] = { }
643643
}
644644
}
645645
}
@@ -770,8 +770,8 @@ export const dataFormulatorSlice = createSlice({
770770
state.selectedModelId = defaultModels[0].id;
771771
}
772772

773-
console.log("load model complete");
774-
console.log("state.models", state.models);
773+
// console.log("load model complete");
774+
// console.log("state.models", state.models);
775775
})
776776
.addCase(fetchCodeExpl.fulfilled, (state, action) => {
777777
let codeExpl = action.payload;

0 commit comments

Comments
 (0)