18
18
from typing import List , Optional
19
19
20
20
import bigframes
21
+ import bigframes .core .guid
22
+ import bigframes .dtypes as dtypes
21
23
22
24
23
25
class Semantics :
@@ -27,6 +29,171 @@ def __init__(self, df) -> None:
27
29
28
30
self ._df = df
29
31
32
def agg(
    self,
    instruction: str,
    model,
    cluster_column: typing.Optional[str] = None,
    max_agg_rows: int = 10,
):
    """
    Performs an aggregation over all rows of the table.

    This method recursively aggregates the input data to produce partial answers
    in parallel, until a single answer remains.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None
        >>> bpd.options.experiments.semantic_operators = True

        >>> import bigframes.ml.llm as llm
        >>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")

        >>> df = bpd.DataFrame(
        ... {
        ...     "Movies": [
        ...         "Titanic",
        ...         "The Wolf of Wall Street",
        ...         "Inception",
        ...     ],
        ...     "Year": [1997, 2013, 2010],
        ... })
        >>> df.semantics.agg(
        ...     "Find the first name shared by all actors in {Movies}. One word answer.",
        ...     model=model,
        ... )
        0    Leonardo
        <BLANKLINE>
        Name: Movies, dtype: string

    Args:
        instruction (str):
            An instruction on how to aggregate the data. This value must contain
            exactly one column reference by name enclosed in braces.
            For example, to reference a column named "movies", use "{movies}" in the
            instruction, like: "Find actor names shared by all {movies}."

        model (bigframes.ml.llm.GeminiTextGenerator):
            A GeminiTextGenerator provided by the Bigframes ML package.

        cluster_column (Optional[str], default None):
            If set, aggregates each cluster before performing aggregations across
            clusters. Clustering based on semantic similarity can improve accuracy
            of the semantic aggregations. The column must have an integer dtype.

        max_agg_rows (int, default 10):
            The maximum number of rows to be aggregated at a time. Must be
            greater than 1.

    Returns:
        bigframes.series.Series: The referenced column, holding the aggregated
        answer(s).

    Raises:
        NotImplementedError: when the instruction refers to more than one column.
            (The original documentation also lists the case where the semantic
            operator experiment is off; that check is not part of this method.)
        ValueError: when the instruction refers to a non-existing column, when
            ``cluster_column`` does not exist, or when ``max_agg_rows`` is not
            greater than 1.
        TypeError: when ``cluster_column`` is not of integer dtype.
    """
    self._validate_model(model)

    columns = self._parse_columns(instruction)
    for column in columns:
        if column not in self._df.columns:
            raise ValueError(f"Column {column} not found.")
    if len(columns) > 1:
        raise NotImplementedError(
            "Semantic aggregations are limited to a single column."
        )
    column = columns[0]

    if max_agg_rows <= 1:
        # The two literals are concatenated, so the first one must end with
        # a space to keep the sentences apart.
        raise ValueError(
            f"Invalid value for `max_agg_rows`: {max_agg_rows}. "
            "It must be greater than 1."
        )

    import bigframes.bigquery as bbq
    import bigframes.dataframe
    import bigframes.series

    df: bigframes.dataframe.DataFrame = self._df.copy()
    user_instruction = self._format_instruction(instruction, columns)

    num_cluster = 1
    if cluster_column is not None:
        if cluster_column not in df.columns:
            raise ValueError(f"Cluster column `{cluster_column} ` not found.")

        cluster_dtype = df[cluster_column].dtype
        if cluster_dtype != dtypes.INT_DTYPE:
            # Report the offending dtype, not the Python class of the column
            # object, so the message tells the caller what to change.
            raise TypeError(
                "Cluster column must be an integer type, not "
                f"{cluster_dtype}"
            )

        num_cluster = len(df[cluster_column].unique())
        df = df.sort_values(cluster_column)
    else:
        # No clustering requested: put every row in one synthetic cluster so
        # the loop below can treat both cases the same way.
        cluster_column = bigframes.core.guid.generate_guid("pid")
        df[cluster_column] = 0

    aggregation_group_id = bigframes.core.guid.generate_guid("agg")
    group_row_index = bigframes.core.guid.generate_guid("gid")
    llm_prompt = bigframes.core.guid.generate_guid("prompt")
    # Replace the index with a fresh sequential one and expose it as a column;
    # it seeds the group ids computed on every pass.
    df = (
        df.reset_index(drop=True)
        .reset_index()
        .rename(columns={"index": aggregation_group_id})
    )

    output_instruction = (
        "Answer user instructions using the provided context from various sources. "
        "Combine all relevant information into a single, concise, well-structured response. "
        f"Instruction: {user_instruction} .\n \n "
    )

    while len(df) > 1:
        # 1-based position of the row inside its group, used as "Source #k".
        df[group_row_index] = (df[aggregation_group_id] % max_agg_rows + 1).astype(
            dtypes.STRING_DTYPE
        )
        # Bucket rows into groups of at most `max_agg_rows`.
        df[aggregation_group_id] = (df[aggregation_group_id] / max_agg_rows).astype(
            dtypes.INT_DTYPE
        )
        df[llm_prompt] = "\t \n Source #" + df[group_row_index] + ": " + df[column]

        if len(df) > num_cluster:
            # Aggregate within each partition
            agg_df = bbq.array_agg(
                df.groupby(by=[cluster_column, aggregation_group_id])
            )
        else:
            # Aggregate cross partitions
            agg_df = bbq.array_agg(df.groupby(by=[aggregation_group_id]))
            agg_df[cluster_column] = agg_df[cluster_column].list[0]

        # Skip if the aggregated group only has a single item
        group_sizes = agg_df[group_row_index].list.len()
        single_row_df: bigframes.series.Series = bbq.array_to_string(
            agg_df[group_sizes <= 1][column],
            delimiter="",
        )
        prompt_s: bigframes.series.Series = bbq.array_to_string(
            agg_df[group_sizes > 1][llm_prompt],
            delimiter="",
        )
        prompt_s = output_instruction + prompt_s  # type:ignore

        # Run model
        predict_df = typing.cast(
            bigframes.dataframe.DataFrame, model.predict(prompt_s)
        )
        # Model answers for multi-row groups, original value for the rest.
        agg_df[column] = predict_df["ml_generate_text_llm_result"].combine_first(
            single_row_df
        )

        agg_df = agg_df.reset_index()
        df = agg_df[[aggregation_group_id, cluster_column, column]]

    return df[column]
196
+
30
197
def filter (self , instruction : str , model ):
31
198
"""
32
199
Filters the DataFrame with the semantics of the user instruction.
@@ -35,9 +202,7 @@ def filter(self, instruction: str, model):
35
202
36
203
>>> import bigframes.pandas as bpd
37
204
>>> bpd.options.display.progress_bar = None
38
-
39
- >>> import bigframes
40
- >>> bigframes.options.experiments.semantic_operators = True
205
+ >>> bpd.options.experiments.semantic_operators = True
41
206
42
207
>>> import bigframes.ml.llm as llm
43
208
>>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
@@ -68,14 +233,22 @@ def filter(self, instruction: str, model):
68
233
ValueError: when the instruction refers to a non-existing column, or when no
69
234
columns are referred to.
70
235
"""
71
- _validate_model (model )
236
+ self ._validate_model (model )
237
+ columns = self ._parse_columns (instruction )
238
+ for column in columns :
239
+ if column not in self ._df .columns :
240
+ raise ValueError (f"Column { column } not found." )
72
241
242
+ user_instruction = self ._format_instruction (instruction , columns )
73
243
output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
74
244
75
245
from bigframes .dataframe import DataFrame
76
246
77
247
results = typing .cast (
78
- DataFrame , model .predict (self ._make_prompt (instruction , output_instruction ))
248
+ DataFrame ,
249
+ model .predict (
250
+ self ._make_prompt (columns , user_instruction , output_instruction )
251
+ ),
79
252
)
80
253
81
254
return self ._df [
@@ -90,9 +263,7 @@ def map(self, instruction: str, output_column: str, model):
90
263
91
264
>>> import bigframes.pandas as bpd
92
265
>>> bpd.options.display.progress_bar = None
93
-
94
- >>> import bigframes
95
- >>> bigframes.options.experiments.semantic_operators = True
266
+ >>> bpd.options.experiments.semantic_operators = True
96
267
97
268
>>> import bigframes.ml.llm as llm
98
269
>>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
@@ -129,8 +300,13 @@ def map(self, instruction: str, output_column: str, model):
129
300
ValueError: when the instruction refers to a non-existing column, or when no
130
301
columns are referred to.
131
302
"""
132
- _validate_model (model )
303
+ self ._validate_model (model )
304
+ columns = self ._parse_columns (instruction )
305
+ for column in columns :
306
+ if column not in self ._df .columns :
307
+ raise ValueError (f"Column { column } not found." )
133
308
309
+ user_instruction = self ._format_instruction (instruction , columns )
134
310
output_instruction = (
135
311
"Based on the provided contenxt, answer the following instruction:"
136
312
)
@@ -139,34 +315,15 @@ def map(self, instruction: str, output_column: str, model):
139
315
140
316
results = typing .cast (
141
317
Series ,
142
- model .predict (self . _make_prompt ( instruction , output_instruction ))[
143
- "ml_generate_text_llm_result"
144
- ],
318
+ model .predict (
319
+ self . _make_prompt ( columns , user_instruction , output_instruction )
320
+ )[ "ml_generate_text_llm_result" ],
145
321
)
146
322
147
323
from bigframes .core .reshape import concat
148
324
149
325
return concat ([self ._df , results .rename (output_column )], axis = 1 )
150
326
151
- def _make_prompt (self , user_instruction : str , output_instruction : str ):
152
- columns = _parse_columns (user_instruction )
153
-
154
- for column in columns :
155
- if column not in self ._df .columns :
156
- raise ValueError (f"Column { column } not found." )
157
-
158
- # Replace column references with names.
159
- user_instruction = user_instruction .format (** {col : col for col in columns })
160
-
161
- prompt_df = self ._df [columns ].copy ()
162
- prompt_df ["prompt" ] = f"{ output_instruction } \n { user_instruction } \n Context: "
163
-
164
- # Combine context from multiple columns.
165
- for col in columns :
166
- prompt_df ["prompt" ] += f"{ col } is `" + prompt_df [col ] + "`\n "
167
-
168
- return prompt_df ["prompt" ]
169
-
170
327
def join (self , other , instruction : str , model , max_rows : int = 1000 ):
171
328
"""
172
329
Joins two dataframes by applying the instruction over each pair of rows from
@@ -176,9 +333,7 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
176
333
177
334
>>> import bigframes.pandas as bpd
178
335
>>> bpd.options.display.progress_bar = None
179
-
180
- >>> import bigframes
181
- >>> bigframes.options.experiments.semantic_operators = True
336
+ >>> bpd.options.experiments.semantic_operators = True
182
337
183
338
>>> import bigframes.ml.llm as llm
184
339
>>> model = llm.GeminiTextGenerator(model_name="gemini-1.5-flash-001")
@@ -221,7 +376,8 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
221
376
Raises:
222
377
ValueError if the amount of data that will be sent for LLM processing is larger than max_rows.
223
378
"""
224
- _validate_model (model )
379
+ self ._validate_model (model )
380
+ columns = self ._parse_columns (instruction )
225
381
226
382
joined_table_rows = len (self ._df ) * len (other )
227
383
@@ -230,8 +386,6 @@ def join(self, other, instruction: str, model, max_rows: int = 1000):
230
386
f"Number of rows that need processing is { joined_table_rows } , which exceeds row limit { max_rows } ."
231
387
)
232
388
233
- columns = _parse_columns (instruction )
234
-
235
389
left_columns = []
236
390
right_columns = []
237
391
@@ -373,18 +527,40 @@ def search(
373
527
374
528
return typing .cast (bigframes .dataframe .DataFrame , search_result )
375
529
530
+ def _make_prompt (
531
+ self , columns : List [str ], user_instruction : str , output_instruction : str
532
+ ):
533
+ prompt_df = self ._df [columns ].copy ()
534
+ prompt_df ["prompt" ] = f"{ output_instruction } \n { user_instruction } \n Context: "
535
+
536
+ # Combine context from multiple columns.
537
+ for col in columns :
538
+ prompt_df ["prompt" ] += f"{ col } is `" + prompt_df [col ] + "`\n "
539
+
540
+ return prompt_df ["prompt" ]
376
541
377
- def _validate_model (model ):
378
- from bigframes .ml .llm import GeminiTextGenerator
542
+ def _parse_columns (self , instruction : str ) -> List [str ]:
543
+ """Extracts column names enclosed in curly braces from the user instruction.
544
+ For example, _parse_columns("{city} is in {continent}") == ["city", "continent"]
545
+ """
546
+ columns = re .findall (r"(?<!{)\{(?!{)(.*?)\}(?!\})" , instruction )
379
547
380
- if not isinstance ( model , GeminiTextGenerator ) :
381
- raise ValueError ("Model is not GeminiText Generator " )
548
+ if not columns :
549
+ raise ValueError ("No column references. " )
382
550
551
+ return columns
383
552
384
- def _parse_columns (instruction : str ) -> List [str ]:
385
- columns = re .findall (r"(?<!{)\{(?!{)(.*?)\}(?!\})" , instruction )
553
+ @staticmethod
554
+ def _format_instruction (instruction : str , columns : List [str ]) -> str :
555
+ """Extracts column names enclosed in curly braces from the user instruction.
556
+ For example, `_format_instruction(["city", "continent"], "{city} is in {continent}")
557
+ == "city is in continent"`
558
+ """
559
+ return instruction .format (** {col : col for col in columns })
386
560
387
- if not columns :
388
- raise ValueError ("No column references" )
561
@staticmethod
def _validate_model(model):
    """Raise ``ValueError`` unless ``model`` is a ``GeminiTextGenerator``."""
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids an import cycle — confirm.
    from bigframes.ml.llm import GeminiTextGenerator

    if isinstance(model, GeminiTextGenerator):
        return

    raise ValueError("Model is not GeminiText Generator")
0 commit comments