Commit 6c52210

Optimize Genie Parse Result (#120)
1 parent 05f2a9d commit 6c52210

2 files changed: 112 additions, 9 deletions


src/databricks_ai_bridge/genie.py

Lines changed: 28 additions & 7 deletions
@@ -1,3 +1,4 @@
+import bisect
 import logging
 import time
 from dataclasses import dataclass
@@ -59,15 +60,35 @@ def _parse_query_result(resp) -> Union[str, pd.DataFrame]:

         rows.append(row)

-    query_result = pd.DataFrame(rows, columns=header).to_markdown()
-
+    dataframe = pd.DataFrame(rows, columns=header)
+    query_result = dataframe.to_markdown()
     tokens_used = _count_tokens(query_result)
-    while tokens_used > MAX_TOKENS_OF_DATA:
-        rows.pop()
-        query_result = pd.DataFrame(rows, columns=header).to_markdown()
-        tokens_used = _count_tokens(query_result)

-    return query_result.strip() if query_result else query_result
+    # If the full result fits, return it
+    if tokens_used <= MAX_TOKENS_OF_DATA:
+        return query_result.strip()
+
+    def is_too_big(n):
+        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
+
+    # Use bisect_left to find the row cutoff that stays within the max token limit in O(log n) steps.
+    # True is the search target: bisect_left returns the first n for which is_too_big(n) is True.
+    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
+
+    # Slice to the found limit
+    truncated_df = dataframe.iloc[:cutoff]
+
+    # Edge case: no rows fit within the token limit, so return an empty string
+    if len(truncated_df) == 0:
+        return ""
+
+    truncated_result = truncated_df.to_markdown()
+
+    # Double-check in case the cutoff overshot by one row
+    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
+        truncated_result = truncated_df.iloc[:-1].to_markdown()
+
+    return truncated_result.strip()


 class Genie:
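
The new code swaps the old trimming loop, which dropped one row at a time and re-rendered the whole DataFrame after every drop (up to O(n) full renders), for a binary search over the row count. Because adding rows can only make the markdown longer, `is_too_big(n)` flips from False to True exactly once, so `bisect.bisect_left` can locate that flip point in about log2(n) renders. Below is a self-contained sketch of the same technique; the token counter and budget are illustrative stand-ins, not the library's real `_count_tokens` or `MAX_TOKENS_OF_DATA`:

```python
import bisect

import pandas as pd  # to_markdown() also requires the tabulate package

MAX_TOKENS_OF_DATA = 100  # illustrative budget, not the library's value


def count_tokens(text: str) -> int:
    # Stand-in tokenizer (~4 characters per token); the real _count_tokens
    # presumably wraps an actual tokenizer.
    return len(text) // 4


df = pd.DataFrame({"id": range(200), "name": ["x" * 10] * 200})


def is_too_big(n: int) -> bool:
    # Monotonic predicate: if n rows overflow the budget, so do n + 1,
    # since more rows can only lengthen the markdown.
    return count_tokens(df.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA


# bisect_left with a key (Python 3.10+) finds the first n in 0..len(df)
# where is_too_big flips to True, using O(log n) renders instead of O(n).
first_too_big = bisect.bisect_left(range(len(df) + 1), True, key=is_too_big)

# first_too_big rows already overflow, so the largest fitting prefix is
# one row shorter (clamped at zero for the nothing-fits edge case).
fitting = df.iloc[: max(first_too_big - 1, 0)]
print(len(fitting), count_tokens(fitting.to_markdown()))
```

The committed code reaches the same answer slightly differently: it slices to the bisect result, then drops one more row if the render still overflows, and returns an empty string when nothing fits.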

tests/databricks_ai_bridge/test_genie.py

Lines changed: 84 additions & 2 deletions
@@ -1,4 +1,6 @@
-from datetime import datetime
+import random
+from datetime import datetime, timedelta
+from io import StringIO
 from unittest.mock import patch

 import pandas as pd
@@ -186,7 +188,7 @@ def test_parse_query_result_with_null_values():
     assert result == expected_df.to_markdown()


-def test_parse_query_result_trims_large_data():
+def test_parse_query_result_trims_data():
     # patch MAX_TOKENS_OF_DATA to 100 for this test
     with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", 100):
         resp = {
@@ -232,6 +234,86 @@ def test_parse_query_result_trims_large_data():
         assert _count_tokens(result) <= 100


+def markdown_to_dataframe(markdown_str: str) -> pd.DataFrame:
+    if markdown_str == "":
+        return pd.DataFrame()
+
+    lines = markdown_str.strip().splitlines()
+
+    # Remove Markdown separator row (2nd line)
+    lines = [line.strip().strip("|") for i, line in enumerate(lines) if i != 1]
+
+    # Re-join cleaned lines and parse
+    cleaned_markdown = "\n".join(lines)
+    df = pd.read_csv(StringIO(cleaned_markdown), sep="|")
+
+    # Strip whitespace from column names and values
+    df.columns = [col.strip() for col in df.columns]
+    df = df.applymap(lambda x: x.strip() if isinstance(x, str) else x)
+
+    # Drop the first column
+    df = df.drop(columns=[df.columns[0]])
+
+    return df
+
+
+@pytest.mark.parametrize("max_tokens", [1, 100, 1000, 2000, 8000, 10000, 15000, 19000, 100000])
+def test_parse_query_result_trims_large_data(max_tokens):
+    """
+    Ensure _parse_query_result trims output to stay within token limits.
+    """
+    with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", max_tokens):
+        base_date = datetime(2023, 1, 1)
+        names = ["Alice", "Bob", "Charlie", "David", "Eve", "Frank", "Grace", "Hank", "Ivy", "Jack"]
+
+        data_array = [
+            [
+                str(i + 1),
+                random.choice(names),
+                (base_date + timedelta(days=random.randint(0, 365))).strftime("%Y-%m-%dT%H:%M:%SZ"),
+            ]
+            for i in range(1000)
+        ]
+
+        response = {
+            "manifest": {
+                "schema": {
+                    "columns": [
+                        {"name": "id", "type_name": "INT"},
+                        {"name": "name", "type_name": "STRING"},
+                        {"name": "created_at", "type_name": "TIMESTAMP"},
+                    ]
+                }
+            },
+            "result": {"data_array": data_array},
+        }
+
+        markdown_result = _parse_query_result(response)
+        result_df = markdown_to_dataframe(markdown_result)
+
+        expected_df = pd.DataFrame(
+            {
+                "id": [int(row[0]) for row in data_array],
+                "name": [row[1] for row in data_array],
+                "created_at": [
+                    datetime.strptime(row[2], "%Y-%m-%dT%H:%M:%SZ").date() for row in data_array
+                ],
+            }
+        )
+
+        expected_markdown = (
+            "" if len(result_df) == 0 else expected_df[: len(result_df)].to_markdown()
+        )
+        # Ensure result matches expected subset and respects token limit
+        assert markdown_result == expected_markdown
+        assert _count_tokens(markdown_result) <= max_tokens
+        # Ensure adding one more row would exceed token limit or we're at full length
+        next_row_exceeds = (
+            _count_tokens(expected_df.iloc[: len(result_df) + 1].to_markdown()) > max_tokens
+        )
+        assert len(result_df) == len(expected_df) or next_row_exceeds
+
+
 def test_poll_query_results_max_iterations(genie, mock_workspace_client):
     # patch MAX_ITERATIONS to 2 for this test and sleep to avoid delays
     with (
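
For reference, the round trip that the new `markdown_to_dataframe` test helper performs can be sketched in isolation. `DataFrame.to_markdown()` emits a pipe-delimited table whose second line is the `|---|` separator and whose first column is the rendered index; stripping those leaves something `pd.read_csv` can parse back. A minimal sketch with made-up values (the two-row table below is not the test's data):

```python
from io import StringIO

import pandas as pd

original = pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
markdown = original.to_markdown()

# Drop the |---| separator (second line) and the outer pipes on each line.
lines = [
    line.strip().strip("|")
    for i, line in enumerate(markdown.splitlines())
    if i != 1
]

# What remains is pipe-separated values that read_csv can handle.
parsed = pd.read_csv(StringIO("\n".join(lines)), sep="|")
parsed.columns = [col.strip() for col in parsed.columns]

# The first parsed column is the index that to_markdown() rendered; drop it.
parsed = parsed.drop(columns=[parsed.columns[0]])

# Cells keep their column padding, so normalize before comparing.
assert parsed["id"].astype(int).tolist() == [1, 2]
assert parsed["name"].str.strip().tolist() == ["Alice", "Bob"]
```

The parametrized test uses the helper only to count surviving rows; its final asserts then pin down maximality: either every row survived, or rendering one more row would exceed `max_tokens`.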
