Skip to content

Commit 9d6d9dd

Browse files
authored
chore: improve error messages for semantic operators (#1078)
* chore: improve error messages for semantic operators * fix tests
1 parent 2d16f6d commit 9d6d9dd

File tree

2 files changed

+114
-18
lines changed

2 files changed

+114
-18
lines changed

bigframes/operations/semantics.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ def agg(
104104
for column in columns:
105105
if column not in self._df.columns:
106106
raise ValueError(f"Column {column} not found.")
107+
if self._df[column].dtype != dtypes.STRING_DTYPE:
108+
raise TypeError(
109+
"Semantics aggregated column must be a string type, not "
110+
f"{type(self._df[column])}"
111+
)
112+
107113
if len(columns) > 1:
108114
raise NotImplementedError(
109115
"Semantic aggregations are limited to a single column."
@@ -324,6 +330,11 @@ def filter(self, instruction: str, model):
324330
for column in columns:
325331
if column not in self._df.columns:
326332
raise ValueError(f"Column {column} not found.")
333+
if self._df[column].dtype != dtypes.STRING_DTYPE:
334+
raise TypeError(
335+
"Semantics aggregated column must be a string type, not "
336+
f"{type(self._df[column])}"
337+
)
327338

328339
user_instruction = self._format_instruction(instruction, columns)
329340
output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
@@ -372,7 +383,7 @@ def map(self, instruction: str, output_column: str, model):
372383
in the instructions like:
373384
"Get the ingredients of {food}."
374385
375-
result_column_name:
386+
output_column:
376387
The column name of the mapping result.
377388
378389
model:
@@ -391,6 +402,11 @@ def map(self, instruction: str, output_column: str, model):
391402
for column in columns:
392403
if column not in self._df.columns:
393404
raise ValueError(f"Column {column} not found.")
405+
if self._df[column].dtype != dtypes.STRING_DTYPE:
406+
raise TypeError(
407+
"Semantics aggregated column must be a string type, not "
408+
f"{type(self._df[column])}"
409+
)
394410

395411
user_instruction = self._format_instruction(instruction, columns)
396412
output_instruction = (
@@ -512,8 +528,11 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
512528
else:
513529
raise ValueError(f"Column {col} not found")
514530

515-
if not left_columns or not right_columns:
516-
raise ValueError()
531+
if not left_columns:
532+
raise ValueError("No left column references.")
533+
534+
if not right_columns:
535+
raise ValueError("No right column references.")
517536

518537
joined_df = self._df.merge(other, how="cross", suffixes=("_left", "_right"))
519538

@@ -570,13 +589,16 @@ def search(
570589
"""
571590

572591
if search_column not in self._df.columns:
573-
raise ValueError(f"Column {search_column} not found")
592+
raise ValueError(f"Column `{search_column}` not found")
574593

575594
import bigframes.ml.llm as llm
576595

577596
if not isinstance(model, llm.TextEmbeddingGenerator):
578597
raise TypeError(f"Expect a text embedding model, but got: {type(model)}")
579598

599+
if top_k < 1:
600+
raise ValueError("top_k must be an integer greater than or equal to 1.")
601+
580602
embedded_df = model.predict(self._df[search_column])
581603
embedded_table = embedded_df.reset_index().to_gbq()
582604

@@ -855,6 +877,9 @@ def sim_join(
855877
f"Number of rows that need processing is {joined_table_rows}, which exceeds row limit {max_rows}."
856878
)
857879

880+
if top_k < 1:
881+
raise ValueError("top_k must be an integer greater than or equal to 1.")
882+
858883
base_table_embedding_column = guid.generate_guid()
859884
base_table = self._attach_embedding(
860885
other, right_on, base_table_embedding_column, model
@@ -926,4 +951,4 @@ def _validate_model(model):
926951
from bigframes.ml.llm import GeminiTextGenerator
927952

928953
if not isinstance(model, GeminiTextGenerator):
929-
raise ValueError("Model is not GeminiText Generator")
954+
raise TypeError("Model is not GeminiText Generator")

tests/system/large/operations/test_semantics.py

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,21 +82,33 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
8282
marks=pytest.mark.xfail(raises=ValueError),
8383
),
8484
pytest.param(
85-
"{city} is in the {non_existing_column}",
85+
"{Movies} is good",
8686
id="non_existing_column",
8787
marks=pytest.mark.xfail(raises=ValueError),
8888
),
8989
pytest.param(
90-
"{city} is in the {country}",
90+
"{Movies} is better than {Movies}",
9191
id="two_columns",
9292
marks=pytest.mark.xfail(raises=NotImplementedError),
9393
),
94+
pytest.param(
95+
"{Year}",
96+
id="invalid_type",
97+
marks=pytest.mark.xfail(raises=TypeError),
98+
),
9499
],
95100
)
96101
def test_agg_invalid_instruction_raise_error(instruction, gemini_flash_model):
97102
bigframes.options.experiments.semantic_operators = True
98103
df = dataframe.DataFrame(
99-
{"country": ["USA", "Germany"], "city": ["Seattle", "Berlin"]}
104+
data={
105+
"Movies": [
106+
"Titanic",
107+
"The Wolf of Wall Street",
108+
"Killers of the Flower Moon",
109+
],
110+
"Year": [1997, 2013, 2023],
111+
},
100112
)
101113
df.semantics.agg(instruction, gemini_flash_model)
102114

@@ -229,15 +241,26 @@ def test_filter_single_column_reference(session, gemini_flash_model):
229241
@pytest.mark.parametrize(
230242
"instruction",
231243
[
232-
"No column reference",
233-
"{city} is in the {non_existing_column}",
244+
pytest.param(
245+
"No column reference",
246+
id="zero_column",
247+
marks=pytest.mark.xfail(raises=ValueError),
248+
),
249+
pytest.param(
250+
"{city} is in the {non_existing_column}",
251+
id="non_existing_column",
252+
marks=pytest.mark.xfail(raises=ValueError),
253+
),
254+
pytest.param(
255+
"{id}",
256+
id="invalid_type",
257+
marks=pytest.mark.xfail(raises=TypeError),
258+
),
234259
],
235260
)
236261
def test_filter_invalid_instruction_raise_error(instruction, gemini_flash_model):
237262
bigframes.options.experiments.semantic_operators = True
238-
df = dataframe.DataFrame(
239-
{"country": ["USA", "Germany"], "city": ["Seattle", "Berlin"]}
240-
)
263+
df = dataframe.DataFrame({"id": [1, 2], "city": ["Seattle", "Berlin"]})
241264

242265
with pytest.raises(ValueError):
243266
df.semantics.filter(instruction, gemini_flash_model)
@@ -249,7 +272,7 @@ def test_filter_invalid_model_raise_error():
249272
{"country": ["USA", "Germany"], "city": ["Seattle", "Berlin"]}
250273
)
251274

252-
with pytest.raises(ValueError):
275+
with pytest.raises(TypeError):
253276
df.semantics.filter("{city} is the capital of {country}", None)
254277

255278

@@ -290,14 +313,28 @@ def test_map(session, gemini_flash_model):
290313
@pytest.mark.parametrize(
291314
"instruction",
292315
[
293-
"No column reference",
294-
"What is the food made from {ingredient_1} and {non_existing_column}?}",
316+
pytest.param(
317+
"No column reference",
318+
id="zero_column",
319+
marks=pytest.mark.xfail(raises=ValueError),
320+
),
321+
pytest.param(
322+
"What is the food made from {ingredient_1} and {non_existing_column}?}",
323+
id="non_existing_column",
324+
marks=pytest.mark.xfail(raises=ValueError),
325+
),
326+
pytest.param(
327+
"{id}",
328+
id="invalid_type",
329+
marks=pytest.mark.xfail(raises=TypeError),
330+
),
295331
],
296332
)
297333
def test_map_invalid_instruction_raise_error(instruction, gemini_flash_model):
298334
bigframes.options.experiments.semantic_operators = True
299335
df = dataframe.DataFrame(
300336
data={
337+
"id": [1, 2],
301338
"ingredient_1": ["Burger Bun", "Soy Bean"],
302339
"ingredient_2": ["Beef Patty", "Bittern"],
303340
}
@@ -316,7 +353,7 @@ def test_map_invalid_model_raise_error():
316353
},
317354
)
318355

319-
with pytest.raises(ValueError):
356+
with pytest.raises(TypeError):
320357
df.semantics.map(
321358
"What is the food made from {ingredient_1} and {ingredient_2}? One word only.",
322359
"food",
@@ -462,7 +499,7 @@ def test_join_invalid_model_raise_error():
462499
cities = dataframe.DataFrame({"city": ["Seattle", "Berlin"]})
463500
countries = dataframe.DataFrame({"country": ["USA", "UK", "Germany"]})
464501

465-
with pytest.raises(ValueError):
502+
with pytest.raises(TypeError):
466503
cities.semantics.join(countries, "{city} is in {country}", None)
467504

468505

@@ -528,6 +565,19 @@ def test_search_invalid_model_raises_error(session):
528565
df.semantics.search("creatures", "monkey", top_k=2, model=None)
529566

530567

568+
def test_search_invalid_top_k_raises_error(session, text_embedding_generator):
569+
bigframes.options.experiments.semantic_operators = True
570+
df = dataframe.DataFrame(
571+
data={"creatures": ["salmon", "sea urchin", "baboons", "frog", "chimpanzee"]},
572+
session=session,
573+
)
574+
575+
with pytest.raises(ValueError):
576+
df.semantics.search(
577+
"creatures", "monkey", top_k=0, model=text_embedding_generator
578+
)
579+
580+
531581
@pytest.mark.parametrize(
532582
"score_column",
533583
[
@@ -614,6 +664,27 @@ def test_sim_join_invalid_model_raises_error(session):
614664
)
615665

616666

667+
def test_sim_join_invalid_top_k_raises_error(session, text_embedding_generator):
668+
bigframes.options.experiments.semantic_operators = True
669+
df1 = dataframe.DataFrame(
670+
data={"creatures": ["salmon", "cat"]},
671+
session=session,
672+
)
673+
df2 = dataframe.DataFrame(
674+
data={"creatures": ["dog", "tuna"]},
675+
session=session,
676+
)
677+
678+
with pytest.raises(ValueError):
679+
df1.semantics.sim_join(
680+
df2,
681+
left_on="creatures",
682+
right_on="creatures",
683+
top_k=0,
684+
model=text_embedding_generator,
685+
)
686+
687+
617688
def test_sim_join_data_too_large_raises_error(session, text_embedding_generator):
618689
bigframes.options.experiments.semantic_operators = True
619690
df1 = dataframe.DataFrame(

0 commit comments

Comments
 (0)