@@ -381,21 +381,42 @@ def filter(self, instruction: str, model, ground_with_google_search: bool = False
         self._confirm_operation(len(self._df))
 
         df: bigframes.dataframe.DataFrame = self._df[columns].copy()
+        has_blob_column = False
         for column in columns:
+            if df[column].dtype == dtypes.OBJ_REF_DTYPE:
+                # Don't cast blob columns to string
+                has_blob_column = True
+                continue
+
             if df[column].dtype != dtypes.STRING_DTYPE:
                 df[column] = df[column].astype(dtypes.STRING_DTYPE)
 
         user_instruction = self._format_instruction(instruction, columns)
         output_instruction = "Based on the provided context, reply to the following claim by only True or False:"
 
-        results = typing.cast(
-            bigframes.dataframe.DataFrame,
-            model.predict(
-                self._make_prompt(df, columns, user_instruction, output_instruction),
-                temperature=0.0,
-                ground_with_google_search=ground_with_google_search,
-            ),
-        )
+        if has_blob_column:
+            results = typing.cast(
+                bigframes.dataframe.DataFrame,
+                model.predict(
+                    df,
+                    prompt=self._make_multimodel_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                ),
+            )
+        else:
+            results = typing.cast(
+                bigframes.dataframe.DataFrame,
+                model.predict(
+                    self._make_text_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                ),
+            )
 
         return self._df[
             results["ml_generate_text_llm_result"].str.lower().str.contains("true")
@@ -480,7 +501,13 @@ def map(
         self._confirm_operation(len(self._df))
 
         df: bigframes.dataframe.DataFrame = self._df[columns].copy()
+        has_blob_column = False
         for column in columns:
+            if df[column].dtype == dtypes.OBJ_REF_DTYPE:
+                # Don't cast blob columns to string
+                has_blob_column = True
+                continue
+
             if df[column].dtype != dtypes.STRING_DTYPE:
                 df[column] = df[column].astype(dtypes.STRING_DTYPE)
 
@@ -489,14 +516,29 @@ def map(
             "Based on the provided context, answer the following instruction:"
         )
 
-        results = typing.cast(
-            bigframes.series.Series,
-            model.predict(
-                self._make_prompt(df, columns, user_instruction, output_instruction),
-                temperature=0.0,
-                ground_with_google_search=ground_with_google_search,
-            )["ml_generate_text_llm_result"],
-        )
+        if has_blob_column:
+            results = typing.cast(
+                bigframes.series.Series,
+                model.predict(
+                    df,
+                    prompt=self._make_multimodel_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                )["ml_generate_text_llm_result"],
+            )
+        else:
+            results = typing.cast(
+                bigframes.series.Series,
+                model.predict(
+                    self._make_text_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                )["ml_generate_text_llm_result"],
+            )
 
 
         from bigframes.core.reshape.api import concat
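`map` gains the same dispatch; only the cast target (a Series) and the trailing `concat` differ. Under the same assumptions as the filter sketch, and with the argument order recalled from the experimental API rather than shown in this diff:

```python
# Hypothetical: one generated sentence per row, appended as a new column.
described = df.semantics.map("Describe {image} in one sentence", "description", gemini)
```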
@@ -1060,8 +1102,19 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, model
         result_df[embedding_column] = embeddings
         return result_df
 
-    def _make_prompt(
-        self, prompt_df, columns, user_instruction: str, output_instruction: str
+    @staticmethod
+    def _make_multimodel_prompt(
+        prompt_df, columns, user_instruction: str, output_instruction: str
+    ):
+        prompt = [f"{output_instruction}\n{user_instruction}\nContext: "]
+        for col in columns:
+            prompt.extend([f"{col} is ", prompt_df[col]])
+
+        return prompt
+
+    @staticmethod
+    def _make_text_prompt(
+        prompt_df, columns, user_instruction: str, output_instruction: str
     ):
         prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: "
 
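The two builders return differently shaped prompts: `_make_multimodel_prompt` yields a Python list interleaving literal strings with whole Series (blob references included), which the multimodal `predict` consumes row-wise, while `_make_text_prompt` folds everything into one string column. The text builder's accumulation loop falls between this hunk and the next, so the separator below is assumed; this is a pure-pandas analogue of the shapes, not the library's implementation:

```python
import pandas as pd

def make_multimodal_prompt(df, columns, user, output):
    # Alternating literal text and Series; blob columns stay as Series.
    prompt = [f"{output}\n{user}\nContext: "]
    for col in columns:
        prompt.extend([f"{col} is ", df[col]])
    return prompt

def make_text_prompt(df, columns, user, output):
    # One flat string per row; columns were already cast to string upstream.
    s = pd.Series(f"{output}\n{user}\nContext: ", index=df.index)
    for col in columns:
        s = s + f"{col} is " + df[col].astype("string") + "\n"
    return s
```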
@@ -1071,7 +1124,8 @@ def _make_prompt(
 
         return prompt_df["prompt"]
 
-    def _parse_columns(self, instruction: str) -> List[str]:
+    @staticmethod
+    def _parse_columns(instruction: str) -> List[str]:
         """Extracts column names enclosed in curly braces from the user instruction.
         For example, _parse_columns("{city} is in {continent}") == ["city", "continent"]
         """