
Commit d1b87e2

chore: Implement Semantics agg (#1059)

* chore: Implement Semantics agg
* fix tests
* address comments

1 parent 575a10a · commit d1b87e2

File tree

3 files changed: +540 −49 lines changed

bigframes/operations/semantics.py

Lines changed: 222 additions & 46 deletions
@@ -18,6 +18,8 @@
 from typing import List, Optional
 
 import bigframes
+import bigframes.core.guid
+import bigframes.dtypes as dtypes
 
 
 class Semantics:
@@ -27,6 +29,171 @@ def __init__(self, df) -> None:
 
         self._df = df
 
+    def agg(
+        self,
+        instruction: str,
+        model,
+        cluster_column: typing.Optional[str] = None,
+        max_agg_rows: int = 10,
+    ):
+        """
+        Performs an aggregation over all rows of the table.
+
+        This method recursively aggregates the input data to produce partial answers
+        in parallel, until a single answer remains.
+
+        **Examples:**
+
+            >>> import bigframes.pandas as bpd
+            >>> bpd.options.display.progress_bar = None
+            >>> bpd.options.experiments.semantic_operators = True
+
+            >>> import bigframes.ml.llm as llm
+            >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
+
+            >>> df = bpd.DataFrame(
+            ...     {
+            ...         "Movies": [
+            ...             "Titanic",
+            ...             "The Wolf of Wall Street",
+            ...             "Inception",
+            ...         ],
+            ...         "Year": [1997, 2013, 2010],
+            ...     })
+            >>> df.semantics.agg(
+            ...     "Find the first name shared by all actors in {Movies}. One word answer.",
+            ...     model=model,
+            ... )
+            0    Leonardo
+            <BLANKLINE>
+            Name: Movies, dtype: string
+
+        Args:
+            instruction (str):
+                An instruction on how to map the data. This value must contain
+                column references by name enclosed in braces.
+                For example, to reference a column named "movies", use "{movies}" in the
+                instruction, like: "Find actor names shared by all {movies}."
+
+            model (bigframes.ml.llm.GeminiTextGenerator):
+                A GeminiTextGenerator provided by the Bigframes ML package.
+
+            cluster_column (Optional[str], default None):
+                If set, aggregates each cluster before performing aggregations across
+                clusters. Clustering based on semantic similarity can improve the accuracy
+                of the semantic aggregations.
+
+            max_agg_rows (int, default 10):
+                The maximum number of rows to be aggregated at a time.
+
+        Returns:
+            bigframes.dataframe.DataFrame: A new DataFrame with the aggregated answers.
+
+        Raises:
+            NotImplementedError: when the semantic operator experiment is off.
+            ValueError: when the instruction refers to a non-existing column, or when
+                more than one column is referred to.
+        """
+        self._validate_model(model)
+
+        columns = self._parse_columns(instruction)
+        for column in columns:
+            if column not in self._df.columns:
+                raise ValueError(f"Column {column} not found.")
+        if len(columns) > 1:
+            raise NotImplementedError(
+                "Semantic aggregations are limited to a single column."
+            )
+        column = columns[0]
+
+        if max_agg_rows <= 1:
+            raise ValueError(
+                f"Invalid value for `max_agg_rows`: {max_agg_rows}. "
+                "It must be greater than 1."
+            )
+
+        import bigframes.bigquery as bbq
+        import bigframes.dataframe
+        import bigframes.series
+
+        df: bigframes.dataframe.DataFrame = self._df.copy()
+        user_instruction = self._format_instruction(instruction, columns)
+
+        num_cluster = 1
+        if cluster_column is not None:
+            if cluster_column not in df.columns:
+                raise ValueError(f"Cluster column `{cluster_column}` not found.")
+
+            if df[cluster_column].dtype != dtypes.INT_DTYPE:
+                raise TypeError(
+                    "Cluster column must be an integer type, not "
+                    f"{type(df[cluster_column])}"
+                )
+
+            num_cluster = len(df[cluster_column].unique())
+            df = df.sort_values(cluster_column)
+        else:
+            cluster_column = bigframes.core.guid.generate_guid("pid")
+            df[cluster_column] = 0
+
+        aggregation_group_id = bigframes.core.guid.generate_guid("agg")
+        group_row_index = bigframes.core.guid.generate_guid("gid")
+        llm_prompt = bigframes.core.guid.generate_guid("prompt")
+        df = (
+            df.reset_index(drop=True)
+            .reset_index()
+            .rename(columns={"index": aggregation_group_id})
+        )
+
+        output_instruction = (
+            "Answer user instructions using the provided context from various sources. "
+            "Combine all relevant information into a single, concise, well-structured response. "
+            f"Instruction: {user_instruction}.\n\n"
+        )
+
+        while len(df) > 1:
+            df[group_row_index] = (df[aggregation_group_id] % max_agg_rows + 1).astype(
+                dtypes.STRING_DTYPE
+            )
+            df[aggregation_group_id] = (df[aggregation_group_id] / max_agg_rows).astype(
+                dtypes.INT_DTYPE
+            )
+            df[llm_prompt] = "\t\nSource #" + df[group_row_index] + ": " + df[column]
+
+            if len(df) > num_cluster:
+                # Aggregate within each partition
+                agg_df = bbq.array_agg(
+                    df.groupby(by=[cluster_column, aggregation_group_id])
+                )
+            else:
+                # Aggregate across partitions
+                agg_df = bbq.array_agg(df.groupby(by=[aggregation_group_id]))
+                agg_df[cluster_column] = agg_df[cluster_column].list[0]
+
+            # Skip if the aggregated group only has a single item
+            single_row_df: bigframes.series.Series = bbq.array_to_string(
+                agg_df[agg_df[group_row_index].list.len() <= 1][column],
+                delimiter="",
+            )
+            prompt_s: bigframes.series.Series = bbq.array_to_string(
+                agg_df[agg_df[group_row_index].list.len() > 1][llm_prompt],
+                delimiter="",
+            )
+            prompt_s = output_instruction + prompt_s  # type:ignore
+
+            # Run model
+            predict_df = typing.cast(
+                bigframes.dataframe.DataFrame, model.predict(prompt_s)
+            )
+            agg_df[column] = predict_df["ml_generate_text_llm_result"].combine_first(
+                single_row_df
+            )
+
+            agg_df = agg_df.reset_index()
+            df = agg_df[[aggregation_group_id, cluster_column, column]]
+
+        return df[column]
+
     def filter(self, instruction: str, model):
         """
         Filters the DataFrame with the semantics of the user instruction.
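
The while loop in the new agg method above implements a hierarchical reduce: rows are grouped into batches of at most max_agg_rows, each batch is summarized by one model call, and the partial answers are fed back in until a single answer remains. Here is a minimal, BigFrames-free sketch of that idea, ignoring cluster_column; llm_summarize is a hypothetical stand-in for model.predict and is not part of the BigFrames API.

    from typing import Callable, List

    def hierarchical_agg(
        values: List[str],
        instruction: str,
        llm_summarize: Callable[[str], str],  # hypothetical stand-in for model.predict
        max_agg_rows: int = 10,
    ) -> str:
        """Reduce batches of values until a single answer remains."""
        if max_agg_rows <= 1:
            raise ValueError("max_agg_rows must be greater than 1.")
        while len(values) > 1:
            partial_answers = []
            # Batch consecutive rows, mirroring the aggregation_group_id /
            # group_row_index bookkeeping in the diff above.
            for start in range(0, len(values), max_agg_rows):
                batch = values[start : start + max_agg_rows]
                if len(batch) == 1:
                    # Pass single-item groups through unchanged, like single_row_df above.
                    partial_answers.append(batch[0])
                    continue
                context = "".join(
                    f"\t\nSource #{i + 1}: {value}" for i, value in enumerate(batch)
                )
                prompt = (
                    "Answer user instructions using the provided context from various sources. "
                    "Combine all relevant information into a single, concise, "
                    f"well-structured response. Instruction: {instruction}.\n\n" + context
                )
                partial_answers.append(llm_summarize(prompt))
            values = partial_answers
        return values[0]

With the three-movie doctest above this is a single model call; with 100 rows it would be ten batch calls in the first round and one more in the second.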
@@ -35,9 +202,7 @@ def filter(self, instruction: str, model):
 
             >>> import bigframes.pandas as bpd
             >>> bpd.options.display.progress_bar = None
-
-            >>> import bigframes
-            >>> bigframes.options.experiments.semantic_operators = True
+            >>> bpd.options.experiments.semantic_operators = True
 
             >>> import bigframes.ml.llm as llm
             >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
@@ -68,14 +233,22 @@ def filter(self, instruction: str, model):
             ValueError: when the instruction refers to a non-existing column, or when no
                 columns are referred to.
         """
-        _validate_model(model)
+        self._validate_model(model)
+        columns = self._parse_columns(instruction)
+        for column in columns:
+            if column not in self._df.columns:
+                raise ValueError(f"Column {column} not found.")
 
+        user_instruction = self._format_instruction(instruction, columns)
         output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
 
         from bigframes.dataframe import DataFrame
 
         results = typing.cast(
-            DataFrame, model.predict(self._make_prompt(instruction, output_instruction))
+            DataFrame,
+            model.predict(
+                self._make_prompt(columns, user_instruction, output_instruction)
+            ),
         )
 
         return self._df[
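
The filter change above routes prompt construction through the reworked _make_prompt(columns, user_instruction, output_instruction). Based on the f-strings re-added later in this diff, the per-row prompt looks roughly like the following; the movie claim and the row value are invented examples, only the instruction strings and the "col is `value`" format come from the diff.

    output_instruction = (
        "Based on the provided context, reply to the following claim by only True or False:"
    )
    # After _format_instruction, "{Movies}" in the user's instruction becomes "Movies".
    user_instruction = "The movie Movies was released before 2000"  # invented example
    # For a row where the Movies column holds "Titanic", _make_prompt yields roughly:
    prompt = f"{output_instruction}\n{user_instruction}\nContext: " + "Movies is `Titanic`\n"
    print(prompt)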
@@ -90,9 +263,7 @@ def map(self, instruction: str, output_column: str, model):
 
             >>> import bigframes.pandas as bpd
             >>> bpd.options.display.progress_bar = None
-
-            >>> import bigframes
-            >>> bigframes.options.experiments.semantic_operators = True
+            >>> bpd.options.experiments.semantic_operators = True
 
             >>> import bigframes.ml.llm as llm
             >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
@@ -129,8 +300,13 @@ def map(self, instruction: str, output_column: str, model):
             ValueError: when the instruction refers to a non-existing column, or when no
                 columns are referred to.
         """
-        _validate_model(model)
+        self._validate_model(model)
+        columns = self._parse_columns(instruction)
+        for column in columns:
+            if column not in self._df.columns:
+                raise ValueError(f"Column {column} not found.")
 
+        user_instruction = self._format_instruction(instruction, columns)
         output_instruction = (
             "Based on the provided contenxt, answer the following instruction:"
         )
@@ -139,34 +315,15 @@ def map(self, instruction: str, output_column: str, model):
 
         results = typing.cast(
             Series,
-            model.predict(self._make_prompt(instruction, output_instruction))[
-                "ml_generate_text_llm_result"
-            ],
+            model.predict(
+                self._make_prompt(columns, user_instruction, output_instruction)
+            )["ml_generate_text_llm_result"],
         )
 
         from bigframes.core.reshape import concat
 
         return concat([self._df, results.rename(output_column)], axis=1)
 
-    def _make_prompt(self, user_instruction: str, output_instruction: str):
-        columns = _parse_columns(user_instruction)
-
-        for column in columns:
-            if column not in self._df.columns:
-                raise ValueError(f"Column {column} not found.")
-
-        # Replace column references with names.
-        user_instruction = user_instruction.format(**{col: col for col in columns})
-
-        prompt_df = self._df[columns].copy()
-        prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: "
-
-        # Combine context from multiple columns.
-        for col in columns:
-            prompt_df["prompt"] += f"{col} is `" + prompt_df[col] + "`\n"
-
-        return prompt_df["prompt"]
-
     def join(self, other, instruction: str, model, max_rows: int = 1000):
         """
         Joines two dataframes by applying the instruction over each pair of rows from
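
The map path above still finishes by renaming the model output series to output_column and concatenating it onto the original frame. A plain-pandas analogue of that final concat, with invented movie and genre values, looks like this.

    import pandas as pd

    df = pd.DataFrame({"Movies": ["Titanic", "Inception"]})
    # Stand-in for the "ml_generate_text_llm_result" series returned by model.predict.
    results = pd.Series(["romance drama", "sci-fi thriller"])
    output_column = "Genre"
    mapped = pd.concat([df, results.rename(output_column)], axis=1)
    print(mapped)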
@@ -176,9 +333,7 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
 
             >>> import bigframes.pandas as bpd
             >>> bpd.options.display.progress_bar = None
-
-            >>> import bigframes
-            >>> bigframes.options.experiments.semantic_operators = True
+            >>> bpd.options.experiments.semantic_operators = True
 
             >>> import bigframes.ml.llm as llm
             >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
@@ -221,7 +376,8 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
         Raises:
             ValueError if the amount of data that will be sent for LLM processing is larger than max_rows.
         """
-        _validate_model(model)
+        self._validate_model(model)
+        columns = self._parse_columns(instruction)
 
         joined_table_rows = len(self._df) * len(other)
 
@@ -230,8 +386,6 @@
                 f"Number of rows that need processing is {joined_table_rows}, which exceeds row limit {max_rows}."
             )
 
-        columns = _parse_columns(instruction)
-
         left_columns = []
         right_columns = []
 
@@ -373,18 +527,40 @@ def search(
 
         return typing.cast(bigframes.dataframe.DataFrame, search_result)
 
+    def _make_prompt(
+        self, columns: List[str], user_instruction: str, output_instruction: str
+    ):
+        prompt_df = self._df[columns].copy()
+        prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: "
+
+        # Combine context from multiple columns.
+        for col in columns:
+            prompt_df["prompt"] += f"{col} is `" + prompt_df[col] + "`\n"
+
+        return prompt_df["prompt"]
 
-def _validate_model(model):
-    from bigframes.ml.llm import GeminiTextGenerator
+    def _parse_columns(self, instruction: str) -> List[str]:
+        """Extracts column names enclosed in curly braces from the user instruction.
+        For example, _parse_columns("{city} is in {continent}") == ["city", "continent"]
+        """
+        columns = re.findall(r"(?<!{)\{(?!{)(.*?)\}(?!\})", instruction)
 
-    if not isinstance(model, GeminiTextGenerator):
-        raise ValueError("Model is not GeminiText Generator")
+        if not columns:
+            raise ValueError("No column references.")
 
+        return columns
 
-def _parse_columns(instruction: str) -> List[str]:
-    columns = re.findall(r"(?<!{)\{(?!{)(.*?)\}(?!\})", instruction)
+    @staticmethod
+    def _format_instruction(instruction: str, columns: List[str]) -> str:
+        """Replaces column references in the instruction with their column names.
+        For example, `_format_instruction(["city", "continent"], "{city} is in {continent}")
+        == "city is in continent"`
+        """
+        return instruction.format(**{col: col for col in columns})
 
-    if not columns:
-        raise ValueError("No column references")
+    @staticmethod
+    def _validate_model(model):
+        from bigframes.ml.llm import GeminiTextGenerator
 
-    return columns
+        if not isinstance(model, GeminiTextGenerator):
+            raise ValueError("Model is not GeminiText Generator")
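
The column-reference regex now lives in _parse_columns and matches only single braces, so doubled braces act as literals. A quick standalone check; the sample strings are invented, while the pattern and the "No column references." error come from the diff.

    import re

    PATTERN = r"(?<!{)\{(?!{)(.*?)\}(?!\})"

    print(re.findall(PATTERN, "{city} is in {continent}"))       # ['city', 'continent']
    print(re.findall(PATTERN, "literal {{braces}} are skipped"))  # []

An instruction with no single-brace references raises ValueError("No column references.") per the added check.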
