Skip to content

Commit 5329613

Browse files
committed
Update sql
1 parent d15a9b6 commit 5329613

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,29 @@ async def on_messages_stream(
101101
if schema not in final_schemas:
102102
final_schemas.append(schema)
103103

104-
final_columns = []
105-
for column_value_result in column_value_results:
104+
columns_for_filter = {}
105+
values_for_filter = {}
106+
for filter, column_value_result in zip(
107+
loaded_entity_result["filter_conditions"], column_value_results
108+
):
109+
columns_for_filter[filter] = []
110+
values_for_filter[filter] = []
106111
for column in column_value_result:
107-
if column not in final_columns:
108-
final_columns.append(column)
112+
if column["Column"] not in columns_for_filter[filter]:
113+
columns_for_filter[filter].append(column["Column"])
109114

110-
all_column_lengths = [len(column) for column in final_columns]
115+
if column["Value"] not in values_for_filter[filter]:
116+
values_for_filter[filter].append(column["Value"])
117+
118+
num_all_values = [len(filter) for filter in values_for_filter]
119+
num_all_columns = [len(filter) for filter in columns_for_filter]
111120

112121
final_results = {
113-
"MANDATORY_DISAMBIGUATION": max(all_column_lengths) > 3
114-
or len(final_columns) > 3,
115-
"schemas": final_schemas,
116-
"column_values": final_columns,
122+
"MANDATORY_DISAMBIGUATION": max(num_all_values) > 3
123+
or max(num_all_columns) > 3,
124+
"COLUMN_OPTIONS_FOR_FILTERS": columns_for_filter,
125+
"VALUE_OPTIONS_FOR_FILTERS": values_for_filter,
126+
"SCHEMA_OPTIONS": final_schemas,
117127
}
118128

119129
logging.info(f"Final results: {final_results}")

0 commit comments

Comments
 (0)