Skip to content

Commit 2a0ffac

Browse files
authored
chore: Semantic operations - support non-string types, fix flaky top_k doctests (#1099)
1 parent 9aff171 commit 2a0ffac

File tree

2 files changed

+80
-62
lines changed

2 files changed

+80
-62
lines changed

bigframes/operations/semantics.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,20 @@ def agg(
9898
ValueError: when the instruction refers to a non-existing column, or when
9999
more than one column is referred to.
100100
"""
101-
self._validate_model(model)
101+
import bigframes.bigquery as bbq
102+
import bigframes.dataframe
103+
import bigframes.series
102104

105+
self._validate_model(model)
103106
columns = self._parse_columns(instruction)
107+
108+
df: bigframes.dataframe.DataFrame = self._df.copy()
104109
for column in columns:
105110
if column not in self._df.columns:
106111
raise ValueError(f"Column {column} not found.")
107-
if self._df[column].dtype != dtypes.STRING_DTYPE:
108-
raise TypeError(
109-
"Semantics aggregated column must be a string type, not "
110-
f"{type(self._df[column])}"
111-
)
112+
113+
if df[column].dtype != dtypes.STRING_DTYPE:
114+
df[column] = df[column].astype(dtypes.STRING_DTYPE)
112115

113116
if len(columns) > 1:
114117
raise NotImplementedError(
@@ -122,11 +125,6 @@ def agg(
122125
"It must be greater than 1."
123126
)
124127

125-
import bigframes.bigquery as bbq
126-
import bigframes.dataframe
127-
import bigframes.series
128-
129-
df: bigframes.dataframe.DataFrame = self._df.copy()
130128
user_instruction = self._format_instruction(instruction, columns)
131129

132130
num_cluster = 1
@@ -325,26 +323,27 @@ def filter(self, instruction: str, model):
325323
ValueError: when the instruction refers to a non-existing column, or when no
326324
columns are referred to.
327325
"""
326+
import bigframes.dataframe
327+
import bigframes.series
328+
328329
self._validate_model(model)
329330
columns = self._parse_columns(instruction)
330331
for column in columns:
331332
if column not in self._df.columns:
332333
raise ValueError(f"Column {column} not found.")
333-
if self._df[column].dtype != dtypes.STRING_DTYPE:
334-
raise TypeError(
335-
"Semantics aggregated column must be a string type, not "
336-
f"{type(self._df[column])}"
337-
)
334+
335+
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
336+
for column in columns:
337+
if df[column].dtype != dtypes.STRING_DTYPE:
338+
df[column] = df[column].astype(dtypes.STRING_DTYPE)
338339

339340
user_instruction = self._format_instruction(instruction, columns)
340341
output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
341342

342-
from bigframes.dataframe import DataFrame
343-
344343
results = typing.cast(
345-
DataFrame,
344+
bigframes.dataframe.DataFrame,
346345
model.predict(
347-
self._make_prompt(columns, user_instruction, output_instruction),
346+
self._make_prompt(df, columns, user_instruction, output_instruction),
348347
temperature=0.0,
349348
),
350349
)
@@ -398,28 +397,29 @@ def map(self, instruction: str, output_column: str, model):
398397
ValueError: when the instruction refers to a non-existing column, or when no
399398
columns are referred to.
400399
"""
400+
import bigframes.dataframe
401+
import bigframes.series
402+
401403
self._validate_model(model)
402404
columns = self._parse_columns(instruction)
403405
for column in columns:
404406
if column not in self._df.columns:
405407
raise ValueError(f"Column {column} not found.")
406-
if self._df[column].dtype != dtypes.STRING_DTYPE:
407-
raise TypeError(
408-
"Semantics aggregated column must be a string type, not "
409-
f"{type(self._df[column])}"
410-
)
408+
409+
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
410+
for column in columns:
411+
if df[column].dtype != dtypes.STRING_DTYPE:
412+
df[column] = df[column].astype(dtypes.STRING_DTYPE)
411413

412414
user_instruction = self._format_instruction(instruction, columns)
413415
output_instruction = (
414416
"Based on the provided contenxt, answer the following instruction:"
415417
)
416418

417-
from bigframes.series import Series
418-
419419
results = typing.cast(
420-
Series,
420+
bigframes.series.Series,
421421
model.predict(
422-
self._make_prompt(columns, user_instruction, output_instruction),
422+
self._make_prompt(df, columns, user_instruction, output_instruction),
423423
temperature=0.0,
424424
)["ml_generate_text_llm_result"],
425425
)
@@ -683,6 +683,9 @@ def top_k(self, instruction: str, model, k=10):
683683
ValueError: when the instruction refers to a non-existing column, or when no
684684
columns are referred to.
685685
"""
686+
import bigframes.dataframe
687+
import bigframes.series
688+
686689
self._validate_model(model)
687690
columns = self._parse_columns(instruction)
688691
for column in columns:
@@ -692,12 +695,12 @@ def top_k(self, instruction: str, model, k=10):
692695
raise NotImplementedError(
693696
"Semantic aggregations are limited to a single column."
694697
)
698+
699+
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
695700
column = columns[0]
696-
if self._df[column].dtype != dtypes.STRING_DTYPE:
697-
raise TypeError(
698-
"Referred column must be a string type, not "
699-
f"{type(self._df[column])}"
700-
)
701+
if df[column].dtype != dtypes.STRING_DTYPE:
702+
df[column] = df[column].astype(dtypes.STRING_DTYPE)
703+
701704
# `index` is reserved for the `reset_index` below.
702705
if column == "index":
703706
raise ValueError(
@@ -709,12 +712,7 @@ def top_k(self, instruction: str, model, k=10):
709712

710713
user_instruction = self._format_instruction(instruction, columns)
711714

712-
import bigframes.dataframe
713-
import bigframes.series
714-
715-
df: bigframes.dataframe.DataFrame = self._df[columns].copy()
716715
n = df.shape[0]
717-
718716
if k >= n:
719717
return df
720718

@@ -762,17 +760,17 @@ def _topk_partition(
762760

763761
# Random pivot selection for improved average quickselect performance.
764762
pending_df = df[df[status_column].isna()]
765-
pivot_iloc = np.random.randint(0, pending_df.shape[0] - 1)
763+
pivot_iloc = np.random.randint(0, pending_df.shape[0])
766764
pivot_index = pending_df.iloc[pivot_iloc]["index"]
767765
pivot_df = pending_df[pending_df["index"] == pivot_index]
768766

769767
# Build a prompt to compare the pivot item's relevance to other pending items.
770768
prompt_s = pending_df[pending_df["index"] != pivot_index][column]
771769
prompt_s = (
772770
f"{output_instruction}\n\nQuestion: {user_instruction}\n"
773-
+ "\nDocument 1: "
771+
+ f"\nDocument 1: {column} "
774772
+ pivot_df.iloc[0][column]
775-
+ "\nDocument 2: "
773+
+ f"\nDocument 2: {column} "
776774
+ prompt_s # type:ignore
777775
)
778776

@@ -920,9 +918,8 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, mode
920918
return result_df
921919

922920
def _make_prompt(
923-
self, columns: List[str], user_instruction: str, output_instruction: str
921+
self, prompt_df, columns, user_instruction: str, output_instruction: str
924922
):
925-
prompt_df = self._df[columns].copy()
926923
prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: "
927924

928925
# Combine context from multiple columns.

tests/system/large/operations/test_semantics.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def test_semantics_experiment_off_raise_error():
3838
pytest.param(2, None, id="two"),
3939
pytest.param(3, None, id="three"),
4040
pytest.param(4, None, id="four"),
41-
pytest.param(5, "Year", id="two_w_cluster_column"),
42-
pytest.param(6, "Year", id="three_w_cluster_column"),
43-
pytest.param(7, "Year", id="four_w_cluster_column"),
41+
pytest.param(5, "Years", id="two_w_cluster_column"),
42+
pytest.param(6, "Years", id="three_w_cluster_column"),
43+
pytest.param(7, "Years", id="four_w_cluster_column"),
4444
],
4545
)
4646
def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
@@ -56,7 +56,7 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
5656
"Shuttle Island",
5757
"The Great Gatsby",
5858
],
59-
"Year": [1997, 2013, 2023, 2015, 2010, 2010, 2013],
59+
"Years": [1997, 2013, 2023, 2015, 2010, 2010, 2013],
6060
},
6161
session=session,
6262
)
@@ -73,6 +73,29 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
7373
pandas.testing.assert_series_equal(actual_s, expected_s, check_index_type=False)
7474

7575

76+
def test_agg_w_int_column(session, gemini_flash_model):
77+
bigframes.options.experiments.semantic_operators = True
78+
df = dataframe.DataFrame(
79+
data={
80+
"Movies": [
81+
"Killers of the Flower Moon",
82+
"The Great Gatsby",
83+
],
84+
"Years": [2023, 2013],
85+
},
86+
session=session,
87+
)
88+
instruction = "Find the {Years} Leonardo DiCaprio acted in the most movies. Answer with the year only."
89+
actual_s = df.semantics.agg(
90+
instruction,
91+
model=gemini_flash_model,
92+
).to_pandas()
93+
94+
expected_s = pd.Series(["2013 \n"], dtype=dtypes.STRING_DTYPE)
95+
expected_s.name = "Years"
96+
pandas.testing.assert_series_equal(actual_s, expected_s, check_index_type=False)
97+
98+
7699
@pytest.mark.parametrize(
77100
"instruction",
78101
[
@@ -91,11 +114,6 @@ def test_agg(session, gemini_flash_model, max_agg_rows, cluster_column):
91114
id="two_columns",
92115
marks=pytest.mark.xfail(raises=NotImplementedError),
93116
),
94-
pytest.param(
95-
"{Year}",
96-
id="invalid_type",
97-
marks=pytest.mark.xfail(raises=TypeError),
98-
),
99117
],
100118
)
101119
def test_agg_invalid_instruction_raise_error(instruction, gemini_flash_model):
@@ -207,15 +225,21 @@ def test_cluster_by_invalid_model(session, gemini_flash_model):
207225
def test_filter(session, gemini_flash_model):
208226
bigframes.options.experiments.semantic_operators = True
209227
df = dataframe.DataFrame(
210-
data={"country": ["USA", "Germany"], "city": ["Seattle", "Berlin"]},
228+
data={
229+
"country": ["USA", "Germany"],
230+
"city": ["Seattle", "Berlin"],
231+
"year": [2023, 2024],
232+
},
211233
session=session,
212234
)
213235

214236
actual_df = df.semantics.filter(
215-
"{city} is the capital of {country}", gemini_flash_model
237+
"{city} is the capital of {country} in {year}", gemini_flash_model
216238
).to_pandas()
217239

218-
expected_df = pd.DataFrame({"country": ["Germany"], "city": ["Berlin"]}, index=[1])
240+
expected_df = pd.DataFrame(
241+
{"country": ["Germany"], "city": ["Berlin"], "year": [2024]}, index=[1]
242+
)
219243
pandas.testing.assert_frame_equal(
220244
actual_df, expected_df, check_dtype=False, check_index_type=False
221245
)
@@ -282,12 +306,13 @@ def test_map(session, gemini_flash_model):
282306
data={
283307
"ingredient_1": ["Burger Bun", "Soy Bean"],
284308
"ingredient_2": ["Beef Patty", "Bittern"],
309+
"gluten-free": [True, True],
285310
},
286311
session=session,
287312
)
288313

289314
actual_df = df.semantics.map(
290-
"What is the food made from {ingredient_1} and {ingredient_2}? One word only.",
315+
"What is the {gluten-free} food made from {ingredient_1} and {ingredient_2}? One word only.",
291316
"food",
292317
gemini_flash_model,
293318
).to_pandas()
@@ -298,6 +323,7 @@ def test_map(session, gemini_flash_model):
298323
{
299324
"ingredient_1": ["Burger Bun", "Soy Bean"],
300325
"ingredient_2": ["Beef Patty", "Bittern"],
326+
"gluten-free": [True, True],
301327
"food": ["burger", "tofu"],
302328
}
303329
)
@@ -724,11 +750,6 @@ def test_sim_join_data_too_large_raises_error(session, text_embedding_generator)
724750
id="two_columns",
725751
marks=pytest.mark.xfail(raises=NotImplementedError),
726752
),
727-
pytest.param(
728-
"{ID}",
729-
id="invalid_dtypes",
730-
marks=pytest.mark.xfail(raises=TypeError),
731-
),
732753
pytest.param(
733754
"{index}",
734755
id="preserved",

0 commit comments

Comments
 (0)