Skip to content

Commit 06a22ff

Browse files
committed
various fixes to support working with more than 1 table
1 parent 97fe419 commit 06a22ff

File tree

5 files changed

+78
-20
lines changed

5 files changed

+78
-20
lines changed

py-src/data_formulator/agents/agent_data_transform_v2.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,16 @@ def process_gpt_response(self, input_tables, messages, response):
258258
return candidates
259259

260260

261-
def run(self, input_tables, description, expected_fields: list[str], n=1):
261+
def run(self, input_tables, description, expected_fields: list[str], prev_messages: list[dict] = [], n=1):
262+
263+
if len(prev_messages) > 0:
264+
logger.info("=== Previous messages ===>")
265+
formatted_prev_messages = ""
266+
for m in prev_messages:
267+
if m['role'] != 'system':
268+
formatted_prev_messages += f"{m['role']}: \n\n\t{m['content']}\n\n"
269+
logger.info(formatted_prev_messages)
270+
prev_messages = [{"role": "user", "content": '[Previous Messages] Here are the previous messages for your reference:\n\n' + formatted_prev_messages}]
262271

263272
data_summary = generate_data_summary(input_tables, include_data_samples=True)
264273

@@ -272,6 +281,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1):
272281
logger.info(user_query)
273282

274283
messages = [{"role":"system", "content": self.system_prompt},
284+
*prev_messages,
275285
{"role":"user","content": user_query}]
276286

277287
response = completion_response_wrapper(self.client, messages, n)

py-src/data_formulator/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,11 @@ def derive_data():
425425
new_fields = content["new_fields"]
426426
instruction = content["extra_prompt"]
427427

428+
if "additional_messages" in content:
429+
prev_messages = content["additional_messages"]
430+
else:
431+
prev_messages = []
432+
428433
print("spec------------------------------")
429434
print(new_fields)
430435
print(instruction)
@@ -439,7 +444,7 @@ def derive_data():
439444
results = agent.run(input_tables, instruction)
440445
else:
441446
agent = DataTransformationAgentV2(client=client)
442-
results = agent.run(input_tables, instruction, [field['name'] for field in new_fields])
447+
results = agent.run(input_tables, instruction, [field['name'] for field in new_fields], prev_messages)
443448

444449
repair_attempts = 0
445450
while results[0]['status'] == 'error' and repair_attempts == 0: # only try once

src/views/ConceptShelf.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ export const ConceptShelf: FC<ConceptShelfProps> = function ConceptShelf() {
198198
}
199199
}}
200200
>
201-
click to expand fields
201+
{`add fields from ${groupName} for tables joins`}
202202
</Button>
203203
)}
204204
<Box

src/views/EncodingShelfCard.tsx

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ let selectBaseTables = (activeFields: FieldItem[], conceptShelfItems: FieldItem[
8383
baseTables.push(...tablesToAdd.filter(t => !baseTables.map(t2 => t2.id).includes(t.id)));
8484
}
8585

86+
console.log("selectBaseTables baseTables");
87+
console.log(baseTables);
88+
8689
return baseTables;
8790
}
8891

@@ -199,7 +202,6 @@ export const EncodingShelfCard: FC<EncodingShelfCardProps> = function ({ chartId
199202
// reference to states
200203
const tables = useSelector((state: DataFormulatorState) => state.tables);
201204
const charts = useSelector((state: DataFormulatorState) => state.charts);
202-
const betaMode = useSelector((state: DataFormulatorState) => state.betaMode);
203205

204206
let activeModel = useSelector(dfSelectors.getActiveModel);
205207

@@ -293,16 +295,46 @@ export const EncodingShelfCard: FC<EncodingShelfCardProps> = function ({ chartId
293295
let engine = getUrls().SERVER_DERIVE_DATA_URL;
294296

295297
if (mode == "formulate" && currentTable.derive?.dialog) {
296-
messageBody = JSON.stringify({
297-
token: token,
298-
mode,
299-
input_tables: baseTables.map(t => {return { name: t.id.replace(/\.[^/.]+$/ , ""), rows: t.rows }}),
300-
output_fields: activeBaseFields.map(f => { return {name: f.name} }),
301-
dialog: currentTable.derive?.dialog,
302-
new_instruction: instruction,
303-
model: activeModel
304-
})
305-
engine = getUrls().SERVER_REFINE_DATA_URL;
298+
let sourceTableIds = currentTable.derive?.source;
299+
let baseTableIds = baseTables.map(t => t.id);
300+
301+
console.log("sourceTableIds ---- and ---- baseTableIds");
302+
console.log(sourceTableIds);
303+
console.log(baseTableIds);
304+
305+
// Compare if source and base table IDs are different
306+
if (!sourceTableIds.every(id => baseTableIds.includes(id)) ||
307+
!baseTableIds.every(id => sourceTableIds.includes(id))) {
308+
309+
let additionalMessages = currentTable.derive.dialog;
310+
311+
console.log("in here");
312+
console.log(additionalMessages);
313+
314+
// in this case, because table ids has changed, we need to use the additional messages and reformulate
315+
messageBody = JSON.stringify({
316+
token: token,
317+
mode,
318+
input_tables: baseTables.map(t => {return { name: t.id.replace(/\.[^/.]+$/ , ""), rows: t.rows }}),
319+
new_fields: activeBaseFields.map(f => { return {name: f.name} }),
320+
extra_prompt: instruction,
321+
model: activeModel,
322+
additional_messages: additionalMessages
323+
});
324+
engine = getUrls().SERVER_DERIVE_DATA_URL;
325+
} else {
326+
messageBody = JSON.stringify({
327+
token: token,
328+
mode,
329+
input_tables: baseTables.map(t => {return { name: t.id.replace(/\.[^/.]+$/ , ""), rows: t.rows }}),
330+
output_fields: activeBaseFields.map(f => { return {name: f.name} }),
331+
dialog: currentTable.derive?.dialog,
332+
new_instruction: instruction,
333+
model: activeModel
334+
})
335+
engine = getUrls().SERVER_REFINE_DATA_URL;
336+
}
337+
306338
}
307339

308340
let message = {

src/views/EncodingShelfThread.tsx

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ export const EncodingShelfThread: FC<EncodingShelfThreadProps> = function ({ cha
182182

183183
let triggerTable = tables.find(t => t.id == trigger.tableId) as DictTable;
184184

185-
let baseTables = tables.filter(t => trigger.sourceTableIds.includes(t.id));
185+
let baseTables = trigger.sourceTableIds.map(id => tables.find(t => t.id == id)).filter(t => t != undefined);
186+
186187
let overrideTableId = trigger.resultTableId;
187188

188189
if (baseTables.length == 0) {
@@ -227,12 +228,22 @@ export const EncodingShelfThread: FC<EncodingShelfThreadProps> = function ({ cha
227228
token: token,
228229
mode,
229230
input_tables: baseTables.map(t => {return { name: t.id.replace(/\.[^/.]+$/ , ""), rows: t.rows }}),
230-
output_fields: activeBaseFields.map(f => { return {name: f.name } }),
231-
dialog: triggerTable.derive?.dialog,
232-
new_instruction: prompt,
231+
new_fields: activeBaseFields.map(f => { return {name: f.name} }),
232+
extra_prompt: prompt,
233+
additional_messages: triggerTable.derive?.dialog,
233234
model: activeModel
234-
})
235-
engine = getUrls().SERVER_REFINE_DATA_URL;
235+
})
236+
engine = getUrls().SERVER_DERIVE_DATA_URL;
237+
// messageBody = JSON.stringify({
238+
// token: token,
239+
// mode,
240+
// input_tables: baseTables.map(t => {return { name: t.id.replace(/\.[^/.]+$/ , ""), rows: t.rows }}),
241+
// output_fields: activeBaseFields.map(f => { return {name: f.name } }),
242+
// dialog: triggerTable.derive?.dialog,
243+
// new_instruction: prompt,
244+
// model: activeModel
245+
// })
246+
//engine = getUrls().SERVER_REFINE_DATA_URL;
236247
}
237248

238249
let message = {

0 commit comments

Comments
 (0)