@@ -98,17 +98,20 @@ def agg(
98
98
ValueError: when the instruction refers to a non-existing column, or when
99
99
more than one columns are referred to.
100
100
"""
101
- self ._validate_model (model )
101
+ import bigframes .bigquery as bbq
102
+ import bigframes .dataframe
103
+ import bigframes .series
102
104
105
+ self ._validate_model (model )
103
106
columns = self ._parse_columns (instruction )
107
+
108
+ df : bigframes .dataframe .DataFrame = self ._df .copy ()
104
109
for column in columns :
105
110
if column not in self ._df .columns :
106
111
raise ValueError (f"Column { column } not found." )
107
- if self ._df [column ].dtype != dtypes .STRING_DTYPE :
108
- raise TypeError (
109
- "Semantics aggregated column must be a string type, not "
110
- f"{ type (self ._df [column ])} "
111
- )
112
+
113
+ if df [column ].dtype != dtypes .STRING_DTYPE :
114
+ df [column ] = df [column ].astype (dtypes .STRING_DTYPE )
112
115
113
116
if len (columns ) > 1 :
114
117
raise NotImplementedError (
@@ -122,11 +125,6 @@ def agg(
122
125
"It must be greater than 1."
123
126
)
124
127
125
- import bigframes .bigquery as bbq
126
- import bigframes .dataframe
127
- import bigframes .series
128
-
129
- df : bigframes .dataframe .DataFrame = self ._df .copy ()
130
128
user_instruction = self ._format_instruction (instruction , columns )
131
129
132
130
num_cluster = 1
@@ -325,26 +323,27 @@ def filter(self, instruction: str, model):
325
323
ValueError: when the instruction refers to a non-existing column, or when no
326
324
columns are referred to.
327
325
"""
326
+ import bigframes .dataframe
327
+ import bigframes .series
328
+
328
329
self ._validate_model (model )
329
330
columns = self ._parse_columns (instruction )
330
331
for column in columns :
331
332
if column not in self ._df .columns :
332
333
raise ValueError (f"Column { column } not found." )
333
- if self . _df [ column ]. dtype != dtypes . STRING_DTYPE :
334
- raise TypeError (
335
- "Semantics aggregated column must be a string type, not "
336
- f" { type ( self . _df [column ]) } "
337
- )
334
+
335
+ df : bigframes . dataframe . DataFrame = self . _df [ columns ]. copy ()
336
+ for column in columns :
337
+ if df [column ]. dtype != dtypes . STRING_DTYPE :
338
+ df [ column ] = df [ column ]. astype ( dtypes . STRING_DTYPE )
338
339
339
340
user_instruction = self ._format_instruction (instruction , columns )
340
341
output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
341
342
342
- from bigframes .dataframe import DataFrame
343
-
344
343
results = typing .cast (
345
- DataFrame ,
344
+ bigframes . dataframe . DataFrame ,
346
345
model .predict (
347
- self ._make_prompt (columns , user_instruction , output_instruction ),
346
+ self ._make_prompt (df , columns , user_instruction , output_instruction ),
348
347
temperature = 0.0 ,
349
348
),
350
349
)
@@ -398,28 +397,29 @@ def map(self, instruction: str, output_column: str, model):
398
397
ValueError: when the instruction refers to a non-existing column, or when no
399
398
columns are referred to.
400
399
"""
400
+ import bigframes .dataframe
401
+ import bigframes .series
402
+
401
403
self ._validate_model (model )
402
404
columns = self ._parse_columns (instruction )
403
405
for column in columns :
404
406
if column not in self ._df .columns :
405
407
raise ValueError (f"Column { column } not found." )
406
- if self . _df [ column ]. dtype != dtypes . STRING_DTYPE :
407
- raise TypeError (
408
- "Semantics aggregated column must be a string type, not "
409
- f" { type ( self . _df [column ]) } "
410
- )
408
+
409
+ df : bigframes . dataframe . DataFrame = self . _df [ columns ]. copy ()
410
+ for column in columns :
411
+ if df [column ]. dtype != dtypes . STRING_DTYPE :
412
+ df [ column ] = df [ column ]. astype ( dtypes . STRING_DTYPE )
411
413
412
414
user_instruction = self ._format_instruction (instruction , columns )
413
415
output_instruction = (
414
416
"Based on the provided contenxt, answer the following instruction:"
415
417
)
416
418
417
- from bigframes .series import Series
418
-
419
419
results = typing .cast (
420
- Series ,
420
+ bigframes . series . Series ,
421
421
model .predict (
422
- self ._make_prompt (columns , user_instruction , output_instruction ),
422
+ self ._make_prompt (df , columns , user_instruction , output_instruction ),
423
423
temperature = 0.0 ,
424
424
)["ml_generate_text_llm_result" ],
425
425
)
@@ -683,6 +683,9 @@ def top_k(self, instruction: str, model, k=10):
683
683
ValueError: when the instruction refers to a non-existing column, or when no
684
684
columns are referred to.
685
685
"""
686
+ import bigframes .dataframe
687
+ import bigframes .series
688
+
686
689
self ._validate_model (model )
687
690
columns = self ._parse_columns (instruction )
688
691
for column in columns :
@@ -692,12 +695,12 @@ def top_k(self, instruction: str, model, k=10):
692
695
raise NotImplementedError (
693
696
"Semantic aggregations are limited to a single column."
694
697
)
698
+
699
+ df : bigframes .dataframe .DataFrame = self ._df [columns ].copy ()
695
700
column = columns [0 ]
696
- if self ._df [column ].dtype != dtypes .STRING_DTYPE :
697
- raise TypeError (
698
- "Referred column must be a string type, not "
699
- f"{ type (self ._df [column ])} "
700
- )
701
+ if df [column ].dtype != dtypes .STRING_DTYPE :
702
+ df [column ] = df [column ].astype (dtypes .STRING_DTYPE )
703
+
701
704
# `index` is reserved for the `reset_index` below.
702
705
if column == "index" :
703
706
raise ValueError (
@@ -709,12 +712,7 @@ def top_k(self, instruction: str, model, k=10):
709
712
710
713
user_instruction = self ._format_instruction (instruction , columns )
711
714
712
- import bigframes .dataframe
713
- import bigframes .series
714
-
715
- df : bigframes .dataframe .DataFrame = self ._df [columns ].copy ()
716
715
n = df .shape [0 ]
717
-
718
716
if k >= n :
719
717
return df
720
718
@@ -762,17 +760,17 @@ def _topk_partition(
762
760
763
761
# Random pivot selection for improved average quickselect performance.
764
762
pending_df = df [df [status_column ].isna ()]
765
- pivot_iloc = np .random .randint (0 , pending_df .shape [0 ] - 1 )
763
+ pivot_iloc = np .random .randint (0 , pending_df .shape [0 ])
766
764
pivot_index = pending_df .iloc [pivot_iloc ]["index" ]
767
765
pivot_df = pending_df [pending_df ["index" ] == pivot_index ]
768
766
769
767
# Build a prompt to compare the pivot item's relevance to other pending items.
770
768
prompt_s = pending_df [pending_df ["index" ] != pivot_index ][column ]
771
769
prompt_s = (
772
770
f"{ output_instruction } \n \n Question: { user_instruction } \n "
773
- + "\n Document 1: "
771
+ + f "\n Document 1: { column } "
774
772
+ pivot_df .iloc [0 ][column ]
775
- + "\n Document 2: "
773
+ + f "\n Document 2: { column } "
776
774
+ prompt_s # type:ignore
777
775
)
778
776
@@ -920,9 +918,8 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, mode
920
918
return result_df
921
919
922
920
def _make_prompt (
923
- self , columns : List [ str ] , user_instruction : str , output_instruction : str
921
+ self , prompt_df , columns , user_instruction : str , output_instruction : str
924
922
):
925
- prompt_df = self ._df [columns ].copy ()
926
923
prompt_df ["prompt" ] = f"{ output_instruction } \n { user_instruction } \n Context: "
927
924
928
925
# Combine context from multiple columns.
0 commit comments