|
1 | | -from datetime import datetime |
| 1 | +import random |
| 2 | +from datetime import datetime, timedelta |
| 3 | +from io import StringIO |
2 | 4 | from unittest.mock import patch |
3 | 5 |
|
4 | 6 | import pandas as pd |
@@ -186,7 +188,7 @@ def test_parse_query_result_with_null_values(): |
186 | 188 | assert result == expected_df.to_markdown() |
187 | 189 |
|
188 | 190 |
|
189 | | -def test_parse_query_result_trims_large_data(): |
| 191 | +def test_parse_query_result_trims_data(): |
190 | 192 | # patch MAX_TOKENS_OF_DATA to 100 for this test |
191 | 193 | with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", 100): |
192 | 194 | resp = { |
@@ -232,6 +234,86 @@ def test_parse_query_result_trims_large_data(): |
232 | 234 | assert _count_tokens(result) <= 100 |
233 | 235 |
|
234 | 236 |
|
| 237 | +def markdown_to_dataframe(markdown_str: str) -> pd.DataFrame: |
| 238 | + if markdown_str == "": |
| 239 | + return pd.DataFrame() |
| 240 | + |
| 241 | + lines = markdown_str.strip().splitlines() |
| 242 | + |
| 243 | + # Remove Markdown separator row (2nd line) |
| 244 | + lines = [line.strip().strip("|") for i, line in enumerate(lines) if i != 1] |
| 245 | + |
| 246 | + # Re-join cleaned lines and parse |
| 247 | + cleaned_markdown = "\n".join(lines) |
| 248 | + df = pd.read_csv(StringIO(cleaned_markdown), sep="|") |
| 249 | + |
| 250 | + # Strip whitespace from column names and values |
| 251 | + df.columns = [col.strip() for col in df.columns] |
| 252 | + df = df.applymap(lambda x: x.strip() if isinstance(x, str) else x) |
| 253 | + |
| 254 | + # Drop the first column |
| 255 | + df = df.drop(columns=[df.columns[0]]) |
| 256 | + |
| 257 | + return df |
| 258 | + |
| 259 | + |
@pytest.mark.parametrize("max_tokens", [1, 100, 1000, 2000, 8000, 10000, 15000, 19000, 100000])
def test_parse_query_result_trims_large_data(max_tokens):
    """
    Ensure _parse_query_result trims output to stay within token limits.

    For each ``max_tokens`` budget: build a 1000-row result set, parse it,
    and check that the returned markdown (a) equals the untrimmed expected
    markdown truncated to the same row count, (b) fits the budget, and
    (c) is maximal — one more row would overflow, unless everything fit.
    """
    with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", max_tokens):
        # Seeded local RNG keeps the fixture deterministic across runs; the
        # module-level random functions would make failures unreproducible.
        rng = random.Random(42)
        base_date = datetime(2023, 1, 1)
        names = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Hank", "Ivy", "Jack"]

        # 1000 synthetic rows of (id, name, ISO-8601 timestamp) — all strings,
        # mirroring the raw string payload of a Genie data_array.
        data_array = [
            [
                str(i + 1),
                rng.choice(names),
                (base_date + timedelta(days=rng.randint(0, 365))).strftime("%Y-%m-%dT%H:%M:%SZ"),
            ]
            for i in range(1000)
        ]

        response = {
            "manifest": {
                "schema": {
                    "columns": [
                        {"name": "id", "type_name": "INT"},
                        {"name": "name", "type_name": "STRING"},
                        {"name": "created_at", "type_name": "TIMESTAMP"},
                    ]
                }
            },
            "result": {"data_array": data_array},
        }

        markdown_result = _parse_query_result(response)
        result_df = markdown_to_dataframe(markdown_result)

        # Typed mirror of data_array: ints for INT, dates for TIMESTAMP.
        expected_df = pd.DataFrame(
            {
                "id": [int(row[0]) for row in data_array],
                "name": [row[1] for row in data_array],
                "created_at": [
                    datetime.strptime(row[2], "%Y-%m-%dT%H:%M:%SZ").date() for row in data_array
                ],
            }
        )

        expected_markdown = (
            "" if len(result_df) == 0 else expected_df[: len(result_df)].to_markdown()
        )
        # Ensure result matches expected subset and respects token limit
        assert markdown_result == expected_markdown
        assert _count_tokens(markdown_result) <= max_tokens
        # Ensure adding one more row would exceed token limit or we're at full length
        next_row_exceeds = (
            _count_tokens(expected_df.iloc[: len(result_df) + 1].to_markdown()) > max_tokens
        )
        assert len(result_df) == len(expected_df) or next_row_exceeds
| 316 | + |
235 | 317 | def test_poll_query_results_max_iterations(genie, mock_workspace_client): |
236 | 318 | # patch MAX_ITERATIONS to 2 for this test and sleep to avoid delays |
237 | 319 | with ( |
|
0 commit comments