|
16 | 16 | import json |
17 | 17 | import os |
18 | 18 | from datetime import datetime |
| 19 | +import re |
19 | 20 |
|
20 | 21 | from text_2_sql_core.payloads.interaction_payloads import ( |
21 | 22 | QuestionPayload, |
@@ -108,49 +109,69 @@ def extract_disambiguation_request( |
108 | 109 | self, messages: list |
109 | 110 | ) -> DismabiguationRequestPayload: |
110 | 111 | """Extract the disambiguation request from the answer.""" |
111 | | - |
112 | 112 | disambiguation_request = messages[-1].content |
113 | | - |
114 | | - # TODO: Properly extract the disambiguation request |
115 | 113 | return DismabiguationRequestPayload( |
116 | 114 | disambiguation_request=disambiguation_request, |
117 | 115 | ) |
118 | 116 |
|
| 117 | + def parse_message_content(self, content): |
| 118 | + """Parse different message content formats into a dictionary.""" |
| 119 | + if isinstance(content, (list, dict)): |
| 120 | + # If it's already a list or dict, convert to JSON string |
| 121 | + return json.dumps(content) |
| 122 | + |
| 123 | + # Try to extract JSON from markdown-style code blocks |
| 124 | + json_match = re.search(r"```json\s*(.*?)\s*```", content, re.DOTALL) |
| 125 | + if json_match: |
| 126 | + try: |
| 127 | + return json.loads(json_match.group(1)) |
| 128 | + except json.JSONDecodeError: |
| 129 | + pass |
| 130 | + |
| 131 | + # Try parsing as regular JSON |
| 132 | + try: |
| 133 | + return json.loads(content) |
| 134 | + except json.JSONDecodeError: |
| 135 | + pass |
| 136 | + |
| 137 | + # If all parsing attempts fail, return the content as-is |
| 138 | + return content |
| 139 | + |
119 | 140 | def extract_sources(self, messages: list) -> AnswerWithSourcesPayload: |
120 | 141 | """Extract the sources from the answer.""" |
121 | | - |
122 | 142 | answer = messages[-1].content |
123 | | - |
124 | | - sql_query_results = messages[-2].content |
| 143 | + sql_query_results = self.parse_message_content(messages[-2].content) |
125 | 144 |
|
126 | 145 | try: |
127 | | - sql_query_results = json.loads(sql_query_results) |
| 146 | + if isinstance(sql_query_results, str): |
| 147 | + sql_query_results = json.loads(sql_query_results) |
128 | 148 |
|
129 | 149 | logging.info("SQL Query Results: %s", sql_query_results) |
130 | | - |
131 | 150 | payload = AnswerWithSourcesPayload(answer=answer) |
132 | 151 |
|
133 | | - for question, sql_query_result_list in sql_query_results["results"].items(): |
134 | | - logging.info( |
135 | | - "SQL Query Result for question '%s': %s", |
136 | | - question, |
137 | | - sql_query_result_list, |
138 | | - ) |
139 | | - |
140 | | - for sql_query_result in sql_query_result_list: |
141 | | - logging.info("SQL Query Result: %s", sql_query_result) |
142 | | - # Instantiate Source and append to the payload's sources list |
143 | | - source = AnswerWithSourcesPayload.Body.Source( |
144 | | - sql_query=sql_query_result["sql_query"], |
145 | | - sql_rows=sql_query_result["sql_rows"], |
| 152 | + if isinstance(sql_query_results, dict) and "results" in sql_query_results: |
| 153 | + for question, sql_query_result_list in sql_query_results[ |
| 154 | + "results" |
| 155 | + ].items(): |
| 156 | + logging.info( |
| 157 | + "SQL Query Result for question '%s': %s", |
| 158 | + question, |
| 159 | + sql_query_result_list, |
146 | 160 | ) |
147 | | - payload.body.sources.append(source) |
| 161 | + |
| 162 | + for sql_query_result in sql_query_result_list: |
| 163 | + logging.info("SQL Query Result: %s", sql_query_result) |
| 164 | + source = AnswerWithSourcesPayload.Body.Source( |
| 165 | + sql_query=sql_query_result["sql_query"], |
| 166 | + sql_rows=sql_query_result["sql_rows"], |
| 167 | + ) |
| 168 | + payload.body.sources.append(source) |
148 | 169 |
|
149 | 170 | return payload |
150 | 171 |
|
151 | | - except json.JSONDecodeError: |
152 | | - logging.error("Could not load message: %s", sql_query_results) |
153 | | - raise ValueError("Could not load message") |
| 172 | + except Exception as e: |
| 173 | + logging.error("Error processing results: %s", str(e)) |
| 174 | + return AnswerWithSourcesPayload(answer=answer) |
154 | 175 |
|
155 | 176 | async def process_question( |
156 | 177 | self, |
|
0 commit comments