Skip to content

Commit bbedda4

Browse files
committed
supporting multiple LLM backends (requires a little more testing)
1 parent e529772 commit bbedda4

16 files changed

+385
-274
lines changed

py-src/data_formulator/agents/agent_code_explanation.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ def transform_data(df_0):
6666

6767
class CodeExplanationAgent(object):
6868

69-
def __init__(self, client, model):
69+
def __init__(self, client):
7070
self.client = client
71-
self.model = model
7271

7372
def run(self, input_tables, code):
7473

@@ -82,9 +81,7 @@ def run(self, input_tables, code):
8281
{"role":"user","content": user_query}]
8382

8483
###### the part that calls open_ai
85-
response = self.client.chat.completions.create(
86-
model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
87-
top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
84+
response = self.client.get_completion(messages = messages)
8885

8986
logger.info('\n=== explanation output ===>\n')
9087
logger.info(response.choices[0].message.content)

py-src/data_formulator/agents/agent_concept_derive.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,8 @@
167167

168168
class ConceptDeriveAgent(object):
169169

170-
def __init__(self, client, model):
170+
def __init__(self, client):
171171
self.client = client
172-
self.model = model
173172

174173
def run(self, input_table, input_fields, output_field, description, n=1):
175174
"""derive a new concept based on input table, input fields, and output field name, (and description)
@@ -190,9 +189,7 @@ def run(self, input_table, input_fields, output_field, description, n=1):
190189
{"role":"user","content": user_query}]
191190

192191
###### the part that calls open_ai
193-
response = self.client.chat.completions.create(
194-
model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
195-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
192+
response = self.client.get_completion(messages = messages)
196193

197194
#log = {'messages': messages, 'response': response.model_dump(mode='json')}
198195

py-src/data_formulator/agents/agent_data_clean.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@
7878

7979
class DataCleanAgent(object):
8080

81-
def __init__(self, client, model):
82-
self.model = model
81+
def __init__(self, client):
8382
self.client = client
8483

8584
def run(self, content_type, raw_data, image_cleaning_instruction):
@@ -129,9 +128,7 @@ def run(self, content_type, raw_data, image_cleaning_instruction):
129128
messages = [system_message, user_prompt]
130129

131130
###### the part that calls open_ai
132-
response = self.client.chat.completions.create(
133-
model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
134-
top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
131+
response = self.client.get_completion(messages = messages)
135132

136133
candidates = []
137134
for choice in response.choices:

py-src/data_formulator/agents/agent_data_filter.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,8 @@ def filter_row(row, df):
125125

126126
class DataFilterAgent(object):
127127

128-
def __init__(self, client, model):
128+
def __init__(self, client):
129129
self.client = client
130-
self.model = model
131130

132131
def process_gpt_result(self, input_table, response, messages):
133132
#log = {'messages': messages, 'response': response.model_dump(mode='json')}
@@ -177,9 +176,7 @@ def run(self, input_table, description):
177176
{"role":"user","content": user_query}]
178177

179178
###### the part that calls open_ai
180-
response = self.client.chat.completions.create(
181-
model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
182-
top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
179+
response = self.client.get_completion(messages = messages)
183180

184181
return self.process_gpt_result(input_table, response, messages)
185182

@@ -190,8 +187,6 @@ def followup(self, input_table, dialog, new_instruction: str, n=1):
190187
"content": new_instruction + '\nupdate the filter function accordingly'}]
191188

192189
##### the part that calls open_ai
193-
response = self.client.chat.completions.create(
194-
model=self.model, messages=messages, temperature=0.7, max_tokens=1200,
195-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
190+
response = self.client.get_completion(messages = messages)
196191

197192
return self.process_gpt_result(input_table, response, messages)

py-src/data_formulator/agents/agent_data_load.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,8 @@
124124

125125
class DataLoadAgent(object):
126126

127-
def __init__(self, client, model):
127+
def __init__(self, client):
128128
self.client = client
129-
self.model = model
130129

131130
def run(self, input_data, n=1):
132131

@@ -140,9 +139,7 @@ def run(self, input_data, n=1):
140139
{"role":"user","content": user_query}]
141140

142141
###### the part that calls open_ai
143-
response = self.client.chat.completions.create(
144-
model=self.model, messages=messages, temperature=0.2, max_tokens=4096,
145-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
142+
response = self.client.get_completion(messages = messages)
146143

147144
#log = {'messages': messages, 'response': response.model_dump(mode='json')}
148145

py-src/data_formulator/agents/agent_data_rec.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,8 @@ def transform_data(df):
126126

127127
class DataRecAgent(object):
128128

129-
def __init__(self, client, model, system_prompt=None):
129+
def __init__(self, client, system_prompt=None):
130130
self.client = client
131-
self.model = model
132131
self.system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT
133132

134133
def process_gpt_response(self, input_tables, messages, response):
@@ -192,7 +191,7 @@ def run(self, input_tables, description, n=1):
192191
messages = [{"role":"system", "content": self.system_prompt},
193192
{"role":"user","content": user_query}]
194193

195-
response = completion_response_wrapper(self.client, self.model, messages, n)
194+
response = completion_response_wrapper(self.client, messages, n)
196195

197196
return self.process_gpt_response(input_tables, messages, response)
198197

@@ -204,7 +203,6 @@ def followup(self, input_tables, dialog, new_instruction: str, n=1):
204203

205204
messages = [*dialog, {"role":"user", "content": f"Update: \n\n{new_instruction}"}]
206205

207-
##### the part that calls open_ai
208-
response = completion_response_wrapper(self.client, self.model, messages, n)
206+
response = completion_response_wrapper(self.client, messages, n)
209207

210208
return self.process_gpt_response(input_tables, messages, response)

py-src/data_formulator/agents/agent_data_transform_v2.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,10 @@ def transform_data(df):
178178
```
179179
'''
180180

181-
def completion_response_wrapper(client, model, messages, n):
181+
def completion_response_wrapper(client, messages, n):
182182
### wrapper for completion response, especially handling errors
183183
try:
184-
response = client.chat.completions.create(
185-
model=model, messages=messages, temperature=0.7, max_tokens=1200,
186-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
184+
response = client.get_completion(messages = messages)
187185
except Exception as e:
188186
response = e
189187

@@ -192,9 +190,8 @@ def completion_response_wrapper(client, model, messages, n):
192190

193191
class DataTransformationAgentV2(object):
194192

195-
def __init__(self, client, model, system_prompt=None):
193+
def __init__(self, client, system_prompt=None):
196194
self.client = client
197-
self.model = model
198195
self.system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT
199196

200197
def process_gpt_response(self, input_tables, messages, response):
@@ -265,7 +262,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1):
265262
messages = [{"role":"system", "content": self.system_prompt},
266263
{"role":"user","content": user_query}]
267264

268-
response = completion_response_wrapper(self.client, self.model, messages, n)
265+
response = completion_response_wrapper(self.client, messages, n)
269266

270267
return self.process_gpt_response(input_tables, messages, response)
271268

@@ -287,6 +284,6 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi
287284
messages = [*updated_dialog, {"role":"user",
288285
"content": f"Update the code above based on the following instruction:\n\n{json.dumps(goal, indent=4)}"}]
289286

290-
response = completion_response_wrapper(self.client, self.model, messages, n)
287+
response = completion_response_wrapper(self.client, messages, n)
291288

292289
return self.process_gpt_response(input_tables, messages, response)

py-src/data_formulator/agents/agent_data_transformation.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@ def transform_data(df_0):
122122

123123
class DataTransformationAgent(object):
124124

125-
def __init__(self, client, model):
125+
def __init__(self, client):
126126
self.client = client
127-
self.model = model
128127

129128
def process_gpt_response(self, input_tables, messages, response):
130129
"""process gpt response to handle execution"""
@@ -185,9 +184,7 @@ def run(self, input_tables, description, expected_fields: list[str], n=1, enrich
185184
{"role":"user","content": user_query}]
186185

187186
###### the part that calls open_ai
188-
response = self.client.chat.completions.create(
189-
model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
190-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
187+
response = self.client.get_completion(messages = messages)
191188

192189
return self.process_gpt_response(input_tables, messages, response)
193190

@@ -207,9 +204,7 @@ def followup(self, input_tables, dialog, output_fields: list[str], new_instructi
207204
"content": "Update the code above based on the following instruction:\n\n" + new_instruction + output_fields_instr}]
208205

209206
##### the part that calls open_ai
210-
response = self.client.chat.completions.create(
211-
model=self.model, messages=messages, temperature=0.7, max_tokens=1200,
212-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
207+
response = self.client.get_completion(messages = messages)
213208

214209
logger.info(response)
215210

py-src/data_formulator/agents/agent_generic_py_concept.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,8 @@ def derive(row, df):
157157

158158
class GenericPyConceptDeriveAgent(object):
159159

160-
def __init__(self, client, model_version):
160+
def __init__(self, client):
161161
self.client = client
162-
self.model_version = model_version
163162

164163
def process_gpt_response(self, input_table, output_field, response, messages):
165164
#log = {'messages': messages, 'response': response.model_dump(mode='json')}
@@ -220,10 +219,7 @@ def run(self, input_table, output_field, description):
220219
{"role":"user","content": user_query}]
221220

222221
###### the part that calls open_ai
223-
response = self.client.chat.completions.create(
224-
model=self.model_version,
225-
messages = messages, temperature=0.7, max_tokens=1200,
226-
top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
222+
response = self.client.get_completion(messages = messages)
227223

228224
return self.process_gpt_response(input_table, output_field, response, messages)
229225

@@ -234,9 +230,7 @@ def followup(self, input_table, dialog, output_field: str, new_instruction: str,
234230
"content": new_instruction + '\n update the function accordingly'}]
235231

236232
##### the part that calls open_ai
237-
response = self.client.chat.completions.create(
238-
model=self.model, messages=messages, temperature=0.7, max_tokens=1200,
239-
top_p=0.95, n=n, frequency_penalty=0, presence_penalty=0, stop=None)
233+
response = self.client.get_completion(messages = messages)
240234

241235
candidates = self.process_gpt_response(input_table, output_field, response, messages)
242236

py-src/data_formulator/agents/agent_py_concept_derive.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ def derive(writing, reading, math):
131131

132132
class PyConceptDeriveAgent(object):
133133

134-
def __init__(self, client, model):
135-
self.model = model
134+
def __init__(self, client):
136135
self.client = client
137136

138137
def run(self, input_table, input_fields, output_field, description):
@@ -163,9 +162,7 @@ def derive({arg_string}):
163162
{"role":"user","content": user_query}]
164163

165164
###### the part that calls open_ai
166-
response = self.client.chat.completions.create(
167-
model=self.model, messages = messages, temperature=0.7, max_tokens=1200,
168-
top_p=0.95, n=1, frequency_penalty=0, presence_penalty=0, stop=None)
165+
response = self.client.get_completion(messages = messages)
169166

170167
#log = {'messages': messages, 'response': response.model_dump(mode='json')}
171168

0 commit comments

Comments
 (0)