Skip to content

Commit 6433640

Browse files
authored
[deploy] Merge pull request #150 from microsoft/dev
Dev
2 parents 9363049 + d143354 commit 6433640

File tree

10 files changed: +53 additions, −43 deletions

py-src/data_formulator/agent_routes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ def process_data_on_load_request():
181181
client = get_client(content['model'])
182182

183183
logger.info(f" model: {content['model']}")
184+
185+
conn = db_manager.get_connection(session['session_id'])
186+
agent = DataLoadAgent(client=client, conn=conn)
184187

185-
agent = DataLoadAgent(client=client)
186188
candidates = agent.run(content["input_data"])
187189

188190
candidates = [c['content'] for c in candidates if c['status'] == 'ok']

py-src/data_formulator/agents/agent_data_load.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import json
55

66
from data_formulator.agents.agent_utils import extract_json_objects, generate_data_summary
7+
from data_formulator.agents.agent_sql_data_transform import sanitize_table_name, get_sql_table_statistics_str
8+
79
import logging
810

911
logger = logging.getLogger(__name__)
@@ -124,12 +126,18 @@
124126

125127
class DataLoadAgent(object):
126128

127-
def __init__(self, client):
129+
def __init__(self, client, conn):
128130
self.client = client
131+
self.conn = conn
129132

130133
def run(self, input_data, n=1):
131134

132-
data_summary = generate_data_summary([input_data], include_data_samples=True, field_sample_size=30)
135+
if input_data['virtual']:
136+
table_name = sanitize_table_name(input_data['name'])
137+
table_summary_str = get_sql_table_statistics_str(self.conn, table_name, row_sample_size=5, field_sample_size=30)
138+
data_summary = f"[TABLE {table_name}]\n\n{table_summary_str}"
139+
else:
140+
data_summary = generate_data_summary([input_data], include_data_samples=True, field_sample_size=30)
133141

134142
user_query = f"[DATA]\n\n{data_summary}\n\n[OUTPUT]"
135143

py-src/data_formulator/agents/agent_sql_data_transform.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,20 +300,24 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi
300300
return self.process_gpt_sql_response(response, messages)
301301

302302

303-
def get_sql_table_statistics_str(conn, table_name: str) -> str:
303+
def get_sql_table_statistics_str(conn, table_name: str,
304+
row_sample_size: int = 5, # number of rows to be sampled in the sample data part
305+
field_sample_size: int = 7, # number of example values for each field to be sampled
306+
max_val_chars: int = 140 # max number of characters to be shown for each example value
307+
) -> str:
304308
"""Get a string representation of the table statistics"""
305309

306310
table_name = sanitize_table_name(table_name)
307311

308312
# Get column information
309313
columns = conn.execute(f"DESCRIBE {table_name}").fetchall()
310-
sample_data = conn.execute(f"SELECT * FROM {table_name} LIMIT 5").fetchall()
314+
sample_data = conn.execute(f"SELECT * FROM {table_name} LIMIT {row_sample_size}").fetchall()
311315

312316
# Format sample data as pipe-separated string
313317
col_names = [col[0] for col in columns]
314318
formatted_sample_data = "| " + " | ".join(col_names) + " |\n"
315319
for i, row in enumerate(sample_data):
316-
formatted_sample_data += f"{i}| " + " | ".join(str(val) for val in row) + " |\n"
320+
formatted_sample_data += f"{i}| " + " | ".join(str(val)[:max_val_chars]+ "..." if len(str(val)) > max_val_chars else str(val) for val in row) + " |\n"
317321

318322
col_metadata_list = []
319323
for col in columns:
@@ -364,12 +368,12 @@ def get_sql_table_statistics_str(conn, table_name: str) -> str:
364368
(SELECT DISTINCT {quoted_col_name}
365369
FROM {table_name}
366370
WHERE {quoted_col_name} IS NOT NULL
367-
LIMIT 5)
371+
LIMIT {field_sample_size})
368372
"""
369373

370374
sample_values = conn.execute(query_for_sample_values).fetchall()
371375

372-
stats_dict['sample_values'] = sample_values
376+
stats_dict['sample_values'] = [str(val)[:max_val_chars]+ "..." if len(str(val)) > max_val_chars else str(val) for val in sample_values]
373377

374378
col_metadata_list.append({
375379
"column": col_name,

py-src/data_formulator/agents/agent_utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def dedup_data_transform_candidates(candidates):
181181
return [items[0] for _, items in candidate_groups.items()]
182182

183183

184-
def get_field_summary(field_name, df, field_sample_size):
184+
def get_field_summary(field_name, df, field_sample_size, max_val_chars=100):
185185
try:
186186
values = sorted([x for x in list(set(df[field_name].values)) if x != None])
187187
except:
@@ -196,11 +196,22 @@ def get_field_summary(field_name, df, field_sample_size):
196196
else:
197197
val_sample = values[:int(sample_size / 2)] + ["..."] + values[-(sample_size - int(sample_size / 2)):]
198198

199-
val_str = ', '.join([str(s) if ',' not in str(s) else f'"{str(s)}"' for s in val_sample])
199+
def sample_val_cap(val):
200+
if len(str(val)) > max_val_chars:
201+
s = str(val)[:max_val_chars] + "..."
202+
else:
203+
s = str(val)
204+
205+
if ',' in s:
206+
s = f'"{s}"'
207+
208+
return s
209+
210+
val_str = ', '.join([sample_val_cap(str(s)) for s in val_sample])
200211

201212
return f"{field_name} -- type: {df[field_name].dtype}, values: {val_str}"
202213

203-
def generate_data_summary(input_tables, include_data_samples=True, field_sample_size=7):
214+
def generate_data_summary(input_tables, include_data_samples=True, field_sample_size=7, max_val_chars=140):
204215

205216
input_table_names = [f'{string_to_py_varname(t["name"])}' for t in input_tables]
206217

@@ -209,7 +220,7 @@ def generate_data_summary(input_tables, include_data_samples=True, field_sample_
209220
field_summaries = []
210221
for input_data in input_tables:
211222
df = pd.DataFrame(input_data['rows'])
212-
s = '\n\t'.join([get_field_summary(fname, df, field_sample_size) for fname in list(df.columns.values)])
223+
s = '\n\t'.join([get_field_summary(fname, df, field_sample_size, max_val_chars) for fname in list(df.columns.values)])
213224
field_summaries.append(s)
214225

215226
table_field_summaries = [f'table_{i} ({input_table_names[i]}) fields:\n\t{s}' for i, s in enumerate(field_summaries)]

py-src/data_formulator/agents/client_utils.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,28 +64,12 @@ def get_completion(self, messages):
6464
# Configure LiteLLM
6565

6666
if self.endpoint == "openai":
67-
68-
print("--------------------------------")
69-
print(f"self.params: {self.params}")
70-
print(f"self.model: {self.model}")
71-
print(f"self.endpoint: {self.endpoint}")
72-
print(f"self.params['api_key']: {self.params.get('api_key', 'None')}")
73-
print(f"self.params['api_base']: {self.params.get('api_base', 'None')}")
74-
print(f"self.params['api_version']: {self.params.get('api_version', 'None')}")
75-
print("--------------------------------")
76-
77-
7867
client = openai.OpenAI(
79-
base_url=self.params.get("api_base", 'placeholder'),
80-
api_key=self.params.get("api_key", 'placeholder'),
68+
base_url=self.params.get("api_base", None),
69+
api_key=self.params.get("api_key", ""),
8170
timeout=120
8271
)
8372

84-
85-
print("--------------------------------")
86-
print(f"client: {client}")
87-
print("--------------------------------")
88-
8973
completion_params = {
9074
"model": self.model,
9175
"messages": messages,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "data_formulator"
7-
version = "0.2.1"
7+
version = "0.2.1.1"
88

99
requires-python = ">=3.9"
1010
authors = [

src/app/dfSlice.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ export const fetchFieldSemanticType = createAsyncThunk(
163163
headers: { 'Content-Type': 'application/json', },
164164
body: JSON.stringify({
165165
token: Date.now(),
166-
input_data: {name: table.id, rows: table.rows},
166+
input_data: {name: table.id, rows: table.rows, virtual: table.virtual ? true : false},
167167
model: dfSelectors.getActiveModel(state)
168168
}),
169169
};

src/app/utils.tsx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,14 @@ export const assembleVegaChart = (
319319
let sortedValues = JSON.parse(encoding.sortBy)['values'];
320320
encodingObj['sort'] = sortOrder == "ascending" ? sortedValues : sortedValues.reverse();
321321

322-
// // special hack: ensure stack bar and stacked area charts are ordered correctly
323-
// if (channel == 'color' && (vgObj['mark'] == 'bar' || vgObj['mark'] == 'area')) {
324-
// vgObj['encoding']['order'] = {'values': sortedValues};
325-
// }
322+
// special hack: ensure stack bar and stacked area charts are ordered correctly
323+
if (channel == 'color' && (vgObj['mark'] == 'bar' || vgObj['mark'] == 'area')) {
324+
// this is a very interesting hack, it leverages the hidden derived field name used in compiled Vega script to
325+
// handle order of stack bar and stacked area charts
326+
vgObj['encoding']['order'] = {
327+
"field": `color_${field?.name}_sort_index`,
328+
}
329+
}
326330
} catch {
327331
console.warn(`sort error > ${encoding.sortBy}`)
328332
}

src/views/DBTableManager.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ import { useDispatch, useSelector } from 'react-redux';
5353
import { dfActions } from '../app/dfSlice';
5454
import { alpha } from '@mui/material';
5555
import { DataFormulatorState } from '../app/dfSlice';
56+
import { fetchFieldSemanticType } from '../app/dfSlice';
57+
import { AppDispatch } from '../app/store';
5658

5759
export const handleDBDownload = async (sessionId: string) => {
5860
try {
@@ -253,7 +255,7 @@ export const DBTableManager: React.FC = () => {
253255

254256
export const DBTableSelectionDialog: React.FC<{ buttonElement: any }> = function DBTableSelectionDialog({ buttonElement }) {
255257

256-
const dispatch = useDispatch();
258+
const dispatch = useDispatch<AppDispatch>();
257259
const sessionId = useSelector((state: DataFormulatorState) => state.sessionId);
258260

259261
const [tableDialogOpen, setTableDialogOpen] = useState<boolean>(false);
@@ -471,6 +473,7 @@ export const DBTableSelectionDialog: React.FC<{ buttonElement: any }> = function
471473
anchored: true, // by default, db tables are anchored
472474
}
473475
dispatch(dfActions.loadTable(table));
476+
dispatch(fetchFieldSemanticType(table));
474477
setTableDialogOpen(false);
475478
}
476479

src/views/VisualizationView.tsx

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,6 @@ export const ChartEditorFC: FC<{ cachedCandidates: DictTable[],
463463
return renderTableChart(chart, conceptShelfItems, visTableRows);
464464
}
465465

466-
console.log('assembled chart');
467-
console.log(chart.chartType);
468-
console.log(chart.encodingMap);
469-
console.log(visTableRows.slice(0, 10));
470466

471467
let element = <></>;
472468
if (!chart || !checkChartAvailabilityOnPreparedData(chart, conceptShelfItems, visTableRows)) {
@@ -478,8 +474,6 @@ export const ChartEditorFC: FC<{ cachedCandidates: DictTable[],
478474
element = <Box id={id} key={`focused-chart`} ></Box>
479475

480476
let assembledChart = assembleVegaChart(chart.chartType, chart.encodingMap, conceptShelfItems, visTableRows, 48, true);
481-
console.log('assembled chart');
482-
console.log(assembledChart);
483477

484478
assembledChart['resize'] = true;
485479
assembledChart['config'] = {

Comments (0)