diff --git a/genie-tool/genie_tool/tool/table_rag/table_column_filter.py b/genie-tool/genie_tool/tool/table_rag/table_column_filter.py index c499e912d..e80249b11 100644 --- a/genie-tool/genie_tool/tool/table_rag/table_column_filter.py +++ b/genie-tool/genie_tool/tool/table_rag/table_column_filter.py @@ -6,6 +6,7 @@ import textwrap import time import traceback +import json_repair from calendar import day_name from datetime import date @@ -231,8 +232,8 @@ async def _filter_single_table(self, semaphore, table_schema_info: dict) -> dict top_p=0.95, only_content=True): llm_response += chunk - result_dict = json.loads(self._parse_json_result(llm_response)) - + # result_dict = json.loads(self._parse_json_result(llm_response)) + result_dict = json_repair.loads(llm_response) if str(result_dict["relatedFlag"]).lower() == "true": columns = table_schema_info.get("schemaList", []) # column_map = {column["columnId"]: column for column in columns} diff --git a/genie-tool/pyproject.toml b/genie-tool/pyproject.toml index a3c93c538..d89a481fd 100644 --- a/genie-tool/pyproject.toml +++ b/genie-tool/pyproject.toml @@ -33,4 +33,5 @@ dependencies = [ "statsmodels>=0.14.5", "tabulate>=0.9.0", "uvicorn>=0.35.0", + "json_repair>=0.54.0" ]