Skip to content

Commit 0d7a8c6

Browse files
authored
Deprecate Genie Truncation (#171)
1 parent 40e3cf3 commit 0d7a8c6

File tree

2 files changed

+66
-28
lines changed

2 files changed

+66
-28
lines changed

src/databricks_ai_bridge/genie.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import bisect
22
import logging
33
import time
4-
import warnings
54
from dataclasses import dataclass
65
from datetime import datetime
76
from typing import Optional, Union
87

98
import mlflow
109
import pandas as pd
11-
import tiktoken
1210
from databricks.sdk import WorkspaceClient
1311

1412
MAX_TOKENS_OF_DATA = 20000
@@ -17,6 +15,8 @@
1715

1816
# Define a function to count tokens
1917
def _count_tokens(text):
18+
import tiktoken
19+
2020
encoding = tiktoken.encoding_for_model("gpt-4o")
2121
return len(encoding.encode(text))
2222

@@ -62,6 +62,16 @@ def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
6262
rows.append(row)
6363

6464
dataframe = pd.DataFrame(rows, columns=header)
65+
66+
if truncate_results:
67+
query_result = _truncate_result(dataframe)
68+
else:
69+
query_result = dataframe.to_markdown()
70+
71+
return query_result.strip()
72+
73+
74+
def _truncate_result(dataframe):
6575
query_result = dataframe.to_markdown()
6676
tokens_used = _count_tokens(query_result)
6777

@@ -88,15 +98,7 @@ def is_too_big(n):
8898
# Double-check edge case if we overshot by one
8999
if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
90100
truncated_result = truncated_df.iloc[:-1].to_markdown()
91-
92-
if not truncate_results:
93-
warnings.warn(
94-
"Detected large Genie output, truncating it to better fit in LLM context windows. Automatic result truncation in Genie is deprecated and will be disabled by default in a future release; we recommend truncating large Genie results in your agent code instead, if needed. To keep automatic truncation for large Genie outputs enabled, set truncate_results=True when initializing the Genie class.",
95-
DeprecationWarning,
96-
stacklevel=2,
97-
)
98-
99-
return truncated_result.strip()
101+
return truncated_result
100102

101103

102104
class Genie:

tests/databricks_ai_bridge/test_genie.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ def test_parse_query_result_with_null_values():
188188
assert result == expected_df.to_markdown()
189189

190190

191-
def test_parse_query_result_trims_data():
191+
@pytest.mark.parametrize("truncate_results", [True, False])
192+
def test_parse_query_result_trims_data(truncate_results):
192193
# patch MAX_TOKENS_OF_DATA to 100 for this test
193194
with patch("databricks_ai_bridge.genie.MAX_TOKENS_OF_DATA", 100):
194195
resp = {
@@ -216,22 +217,57 @@ def test_parse_query_result_trims_data():
216217
]
217218
},
218219
}
219-
result = _parse_query_result(resp, truncate_results=True)
220-
assert (
221-
result
222-
== pd.DataFrame(
223-
{
224-
"id": [1, 2, 3],
225-
"name": ["Alice", "Bob", "Charlie"],
226-
"created_at": [
227-
datetime(2023, 10, 1).date(),
228-
datetime(2023, 10, 2).date(),
229-
datetime(2023, 10, 3).date(),
230-
],
231-
}
232-
).to_markdown()
233-
)
234-
assert _count_tokens(result) <= 100
220+
result = _parse_query_result(resp, truncate_results=truncate_results)
221+
222+
if truncate_results:
223+
assert (
224+
result
225+
== pd.DataFrame(
226+
{
227+
"id": [1, 2, 3],
228+
"name": ["Alice", "Bob", "Charlie"],
229+
"created_at": [
230+
datetime(2023, 10, 1).date(),
231+
datetime(2023, 10, 2).date(),
232+
datetime(2023, 10, 3).date(),
233+
],
234+
}
235+
).to_markdown()
236+
)
237+
assert _count_tokens(result) <= 100
238+
else:
239+
assert (
240+
result
241+
== pd.DataFrame(
242+
{
243+
"id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
244+
"name": [
245+
"Alice",
246+
"Bob",
247+
"Charlie",
248+
"David",
249+
"Eve",
250+
"Frank",
251+
"Grace",
252+
"Hank",
253+
"Ivy",
254+
"Jack",
255+
],
256+
"created_at": [
257+
datetime(2023, 10, 1).date(),
258+
datetime(2023, 10, 2).date(),
259+
datetime(2023, 10, 3).date(),
260+
datetime(2023, 10, 4).date(),
261+
datetime(2023, 10, 5).date(),
262+
datetime(2023, 10, 6).date(),
263+
datetime(2023, 10, 7).date(),
264+
datetime(2023, 10, 8).date(),
265+
datetime(2023, 10, 9).date(),
266+
datetime(2023, 10, 10).date(),
267+
],
268+
}
269+
).to_markdown()
270+
)
235271

236272

237273
def markdown_to_dataframe(markdown_str: str) -> pd.DataFrame:

0 commit comments

Comments
 (0)