Skip to content

Commit 2c2ca4d

Browse files
authored
[deploy] Merge pull request #99 from microsoft/dev
v0.1.6: multi-table data formulation support
2 parents 0e1f215 + 92c2262 commit 2c2ca4d

16 files changed

+1122
-615
lines changed

py-src/data_formulator/agents/agent_data_rec.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
(4) "visualization_fields" should be no more than 3 (for x,y,legend).
5353
(5) "chart_type" must be one of "point", "bar", "line", or "boxplot"
5454
55-
2. Then, write a python function based on the inferred goal, the function input is a dataframe "df" and the output is the transformed dataframe "transformed_df". "transformed_df" should contain all "output_fields" from the refined goal.
55+
2. Then, write a python function based on the inferred goal, the function input is a dataframe "df" (or multiple dataframes based on tables presented in the [CONTEXT] section) and the output is the transformed dataframe "transformed_df". "transformed_df" should contain all "output_fields" from the refined goal.
5656
The python function must follow the template provided in [TEMPLATE], do not import any other libraries or modify function name. The function should be as simple as possible and easily readable.
5757
If there is no data transformation needed based on "output_fields", the transformation function can simply "return df".
5858
@@ -63,11 +63,15 @@
6363
import collections
6464
import numpy as np
6565
66-
def transform_data(df):
66+
def transform_data(df1, df2, ...):
6767
# complete the template here
6868
return transformed_df
6969
```
7070
71+
note:
72+
- if the user provided one table, then it should be `def transform_data(df1)`; if the user provided multiple tables, then it should be `def transform_data(df1, df2, ...)`, and you should consider how to join the tables to derive the output.
73+
- try to use table names to refer to the input dataframes; for example, if the user provided two tables, city and weather, you can use `transform_data(df_city, df_weather)` to refer to the two dataframes.
74+
7175
3. The [OUTPUT] must only contain a json object representing the refined goal and a python code block representing the transformation code, do not add any extra text explanation.
7276
'''
7377

py-src/data_formulator/agents/agent_data_transform_v2.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
}
4646
```
4747
48-
2. Then, write a python function based on the refined goal, the function input is a dataframe "df" and the output is the transformed dataframe "transformed_df". "transformed_df" should contain all "output_fields" from the refined goal.
48+
2. Then, write a python function based on the refined goal, the function input is a dataframe "df" (or multiple dataframes based on tables presented in the [CONTEXT] section) and the output is the transformed dataframe "transformed_df". "transformed_df" should contain all "output_fields" from the refined goal.
4949
The python function must follow the template provided in [TEMPLATE], do not import any other libraries or modify function name. The function should be as simple as possible and easily readable.
5050
If there is no data transformation needed based on "output_fields", the transformation function can simply "return df".
5151
@@ -56,11 +56,15 @@
5656
import collections
5757
import numpy as np
5858
59-
def transform_data(df):
59+
def transform_data(df1, df2, ...):
6060
# complete the template here
6161
return transformed_df
6262
```
6363
64+
note:
65+
- if the user provided one table, then it should be `def transform_data(df1)`; if the user provided multiple tables, then it should be `def transform_data(df1, df2, ...)`, and you should consider how to join the tables to derive the output.
66+
- try to use table names to refer to the input dataframes; for example, if the user provided two tables, city and weather, you can use `transform_data(df_city, df_weather)` to refer to the two dataframes.
67+
6468
3. The [OUTPUT] must only contain a json object representing the refined goal (including "detailed_instruction", "output_fields", "visualization_fields" and "reason") and a python code block representing the transformation code, do not add any extra text explanation.
6569
'''
6670

@@ -226,6 +230,10 @@ def process_gpt_response(self, input_tables, messages, response):
226230
if len(code_blocks) > 0:
227231
code_str = code_blocks[-1]
228232

233+
for table in input_tables:
234+
logger.info(f"Table: {table['name']}")
235+
logger.info(table['rows'])
236+
229237
try:
230238
result = py_sandbox.run_transform_in_sandbox2020(code_str, [t['rows'] for t in input_tables])
231239
result['code'] = code_str
@@ -254,7 +262,16 @@ def process_gpt_response(self, input_tables, messages, response):
254262
return candidates
255263

256264

257-
def run(self, input_tables, description, expected_fields: list[str], n=1):
265+
def run(self, input_tables, description, expected_fields: list[str], prev_messages: list[dict] = [], n=1):
266+
267+
if len(prev_messages) > 0:
268+
logger.info("=== Previous messages ===>")
269+
formatted_prev_messages = ""
270+
for m in prev_messages:
271+
if m['role'] != 'system':
272+
formatted_prev_messages += f"{m['role']}: \n\n\t{m['content']}\n\n"
273+
logger.info(formatted_prev_messages)
274+
prev_messages = [{"role": "user", "content": '[Previous Messages] Here are the previous messages for your reference:\n\n' + formatted_prev_messages}]
258275

259276
data_summary = generate_data_summary(input_tables, include_data_samples=True)
260277

@@ -268,6 +285,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1):
268285
logger.info(user_query)
269286

270287
messages = [{"role":"system", "content": self.system_prompt},
288+
*prev_messages,
271289
{"role":"user","content": user_query}]
272290

273291
response = completion_response_wrapper(self.client, messages, n)

py-src/data_formulator/agents/client_utils.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
2-
from litellm import completion
2+
import litellm
3+
import openai
34
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
45

5-
66
class Client(object):
77
"""
88
Returns a LiteLLM client configured for the specified endpoint and model.
@@ -15,11 +15,17 @@ def __init__(self, endpoint, model, api_key=None, api_base=None, api_version=No
1515

1616
# other params, including temperature, max_completion_tokens, api_base, api_version
1717
self.params = {
18-
"api_key": api_key,
1918
"temperature": 0.7,
2019
"max_completion_tokens": 1200,
2120
}
2221

22+
if api_key is not None and api_key != "":
23+
self.params["api_key"] = api_key
24+
if api_base is not None and api_base != "":
25+
self.params["api_base"] = api_base
26+
if api_version is not None and api_version != "":
27+
self.params["api_version"] = api_version
28+
2329
if self.endpoint == "gemini":
2430
if model.startswith("gemini/"):
2531
self.model = model
@@ -53,9 +59,24 @@ def get_completion(self, messages):
5359
Supports OpenAI, Azure, Ollama, and other providers via LiteLLM.
5460
"""
5561
# Configure LiteLLM
56-
return completion(
57-
model=self.model,
58-
messages=messages,
59-
drop_params=True,
60-
**self.params
61-
)
62+
63+
if self.endpoint == "openai":
64+
client = openai.OpenAI(
65+
api_key=self.params["api_key"],
66+
base_url=self.params["api_base"] if "api_base" in self.params else None,
67+
timeout=120
68+
)
69+
70+
return client.chat.completions.create(
71+
model=self.model,
72+
messages=messages,
73+
temperature=self.params["temperature"],
74+
max_tokens=self.params["max_completion_tokens"],
75+
)
76+
else:
77+
return litellm.completion(
78+
model=self.model,
79+
messages=messages,
80+
drop_params=True,
81+
**self.params
82+
)

py-src/data_formulator/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,11 @@ def derive_data():
425425
new_fields = content["new_fields"]
426426
instruction = content["extra_prompt"]
427427

428+
if "additional_messages" in content:
429+
prev_messages = content["additional_messages"]
430+
else:
431+
prev_messages = []
432+
428433
print("spec------------------------------")
429434
print(new_fields)
430435
print(instruction)
@@ -439,7 +444,7 @@ def derive_data():
439444
results = agent.run(input_tables, instruction)
440445
else:
441446
agent = DataTransformationAgentV2(client=client)
442-
results = agent.run(input_tables, instruction, [field['name'] for field in new_fields])
447+
results = agent.run(input_tables, instruction, [field['name'] for field in new_fields], prev_messages)
443448

444449
repair_attempts = 0
445450
while results[0]['status'] == 'error' and repair_attempts == 0: # only try once

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "data_formulator"
7-
version = "0.1.5.1"
7+
version = "0.1.6"
88

99
requires-python = ">=3.9"
1010
authors = [

0 commit comments

Comments
 (0)