Skip to content

Commit feacaf4

Browse files
authored
chore: implement semantic join (#1051)
* chore: implement semantic join * remove redundant lines * fix column reference validation * add row size check * Fix doctest in semantics.map
1 parent 5ac217d commit feacaf4

File tree

3 files changed

+582
-15
lines changed

3 files changed

+582
-15
lines changed

bigframes/operations/semantics.py

Lines changed: 124 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import re
1717
import typing
18+
from typing import List
1819

1920
import bigframes
2021

@@ -97,7 +98,7 @@ def map(self, instruction: str, output_column: str, model):
9798
>>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
9899
99100
>>> df = bpd.DataFrame({"ingredient_1": ["Burger Bun", "Soy Bean"], "ingredient_2": ["Beef Patty", "Bittern"]})
100-
>>> df.semantics.map("What is the food made from {ingredient_1} and {ingredient_2}? One word only.", result_column_name="food", model=model)
101+
>>> df.semantics.map("What is the food made from {ingredient_1} and {ingredient_2}? One word only.", output_column="food", model=model)
101102
ingredient_1 ingredient_2 food
102103
0 Burger Bun Beef Patty Burger
103104
<BLANKLINE>
@@ -148,11 +149,7 @@ def map(self, instruction: str, output_column: str, model):
148149
return concat([self._df, results.rename(output_column)], axis=1)
149150

150151
def _make_prompt(self, user_instruction: str, output_instruction: str):
151-
# Validate column references
152-
columns = re.findall(r"(?<!{)\{(?!{)(.*?)\}(?!\})", user_instruction)
153-
154-
if not columns:
155-
raise ValueError("No column references.")
152+
columns = _parse_columns(user_instruction)
156153

157154
for column in columns:
158155
if column not in self._df.columns:
@@ -170,9 +167,130 @@ def _make_prompt(self, user_instruction: str, output_instruction: str):
170167

171168
return prompt_df["prompt"]
172169

170+
def join(self, other, instruction: str, model, max_rows: int = 1000):
    """
    Joins two dataframes by applying the instruction over each pair of rows from
    the left and right table.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None

        >>> import bigframes
        >>> bigframes.options.experiments.semantic_operators = True

        >>> import bigframes.ml.llm as llm
        >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")

        >>> cities = bpd.DataFrame({'city': ['Seattle', 'Ottawa', 'Berlin', 'Shanghai', 'New Delhi']})
        >>> continents = bpd.DataFrame({'continent': ['North America', 'Africa', 'Asia']})

        >>> cities.semantics.join(continents, "{city} is in {continent}", model)
                city      continent
        0    Seattle  North America
        1     Ottawa  North America
        2   Shanghai           Asia
        3  New Delhi           Asia
        <BLANKLINE>
        [4 rows x 2 columns]

    Args:
        other:
            The other dataframe.

        instruction:
            An instruction on how left and right rows can be joined. This value must contain
            column references by name, which should be wrapped in a pair of braces.
            For example: "The {city} belongs to the {country}".
            For column names that are shared between two dataframes, you need to add "_left"
            and "_right" suffix for differentiation. This is especially important when you do
            self joins. For example: "The {employee_name_left} reports to {employee_name_right}"
            You must not add "_left" or "_right" suffix to non-overlapping columns.

        model:
            A GeminiTextGenerator provided by Bigframes ML package.

        max_rows:
            The maximum number of row pairs (rows of the left dataframe times rows of
            the other dataframe) allowed to be sent to the model. If the cross product
            is larger, the method call will end early with an error.

    Returns:
        The joined dataframe.

    Raises:
        ValueError: if the model is not a GeminiTextGenerator, if the amount of data
            that will be sent for LLM processing is larger than max_rows, if the
            instruction contains no column references, if a referenced column is
            ambiguous, unnecessarily suffixed or not found, or if the instruction does
            not reference at least one column from each dataframe.
    """
    _validate_model(model)

    # The join is evaluated over the full cross product, so that is the size
    # that has to fit under the row limit.
    joined_table_rows = len(self._df) * len(other)

    if joined_table_rows > max_rows:
        raise ValueError(
            f"Number of rows that need processing is {joined_table_rows}, which exceeds row limit {max_rows}."
        )

    columns = _parse_columns(instruction)

    left_columns = []
    right_columns = []

    for col in columns:
        if col in self._df.columns and col in other.columns:
            # A shared name must be disambiguated with "_left" / "_right".
            raise ValueError(f"Ambiguous column reference: {col}")

        elif col in self._df.columns:
            left_columns.append(col)

        elif col in other.columns:
            right_columns.append(col)

        elif col.endswith("_left"):
            original_col_name = col[: -len("_left")]
            if (
                original_col_name in self._df.columns
                and original_col_name in other.columns
            ):
                left_columns.append(col)
            elif original_col_name in self._df.columns:
                # The suffix is only valid for names present on both sides.
                raise ValueError(f"Unnecessary suffix for {col}")
            else:
                raise ValueError(f"Column {col} not found")

        elif col.endswith("_right"):
            original_col_name = col[: -len("_right")]
            if (
                original_col_name in self._df.columns
                and original_col_name in other.columns
            ):
                right_columns.append(col)
            elif original_col_name in other.columns:
                # The suffix is only valid for names present on both sides.
                raise ValueError(f"Unnecessary suffix for {col}")
            else:
                raise ValueError(f"Column {col} not found")

        else:
            raise ValueError(f"Column {col} not found")

    # A join condition has to involve both sides; report which one is missing
    # instead of raising an empty ValueError.
    if not left_columns:
        raise ValueError("No left column references.")

    if not right_columns:
        raise ValueError("No right column references.")

    # The suffixes here match the "_left" / "_right" names validated above, so
    # the instruction's references resolve against the cross-joined columns.
    joined_df = self._df.merge(other, how="cross", suffixes=("_left", "_right"))

    return joined_df.semantics.filter(instruction, model).reset_index(drop=True)
173282

174283
def _validate_model(model):
    """Ensure ``model`` is a ``GeminiTextGenerator`` instance.

    Args:
        model: The object to check.

    Raises:
        ValueError: if ``model`` is not an instance of
            ``bigframes.ml.llm.GeminiTextGenerator``.
    """
    # Imported at call time rather than module level.
    # NOTE(review): presumably to avoid an import cycle - confirm.
    from bigframes.ml.llm import GeminiTextGenerator

    if isinstance(model, GeminiTextGenerator):
        return

    raise ValueError("Model is not GeminiText Generator")
288+
289+
290+
def _parse_columns(instruction: str) -> List[str]:
291+
columns = re.findall(r"(?<!{)\{(?!{)(.*?)\}(?!\})", instruction)
292+
293+
if not columns:
294+
raise ValueError("No column references")
295+
296+
return columns

0 commit comments

Comments
 (0)