Commit 2029d08
chore: enable multi-model input for sem_map and sem_filter (#1487)
* feat: enable multi-model input for sem_map and sem_filter
* remove commented out code
* fix format
* polish prompt
* do not use multi-model mode when inputs are all texts
* use mixture of text and images for multimodel tests
* add some small tests to increase coverage
* use test bucket for multimodel test
1 parent 3a0dbe1 commit 2029d08

File tree

3 files changed (+288 -62 lines)

bigframes/operations/semantics.py

Lines changed: 73 additions & 19 deletions
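For orientation before the diff: once any requested column carries a blob (object-ref) dtype, sem_filter now routes the whole DataFrame into the model in multi-model mode. A rough sketch of the user-level call this enables; the bucket path and column name are hypothetical, and `from_glob_path` plus the experimental semantics flag are assumptions based on BigFrames' multimodal blob support:

import bigframes.pandas as bpd
from bigframes.ml import llm

# Assumed: semantic operators are experimental and gated behind this flag.
bpd.options.experiments.semantic_operators = True

# Assumed helper from BigFrames' multimodal support: yields a DataFrame with
# one blob (OBJ_REF_DTYPE) column named "image" for the matched GCS objects.
df = bpd.from_glob_path("gs://my-test-bucket/images/*", name="image")

model = llm.GeminiTextGenerator()

# "{image}" resolves to the blob column, so has_blob_column becomes True
# below and the multi-model predict() branch is taken.
cats = df.semantics.filter("the {image} shows a cat", model)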
@@ -381,21 +381,42 @@ def filter(self, instruction: str, model, ground_with_google_search: bool = False)
         self._confirm_operation(len(self._df))

         df: bigframes.dataframe.DataFrame = self._df[columns].copy()
+        has_blob_column = False
         for column in columns:
+            if df[column].dtype == dtypes.OBJ_REF_DTYPE:
+                # Don't cast blob columns to string
+                has_blob_column = True
+                continue
+
             if df[column].dtype != dtypes.STRING_DTYPE:
                 df[column] = df[column].astype(dtypes.STRING_DTYPE)

         user_instruction = self._format_instruction(instruction, columns)
         output_instruction = "Based on the provided context, reply to the following claim by only True or False:"

-        results = typing.cast(
-            bigframes.dataframe.DataFrame,
-            model.predict(
-                self._make_prompt(df, columns, user_instruction, output_instruction),
-                temperature=0.0,
-                ground_with_google_search=ground_with_google_search,
-            ),
-        )
+        if has_blob_column:
+            results = typing.cast(
+                bigframes.dataframe.DataFrame,
+                model.predict(
+                    df,
+                    prompt=self._make_multimodel_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                ),
+            )
+        else:
+            results = typing.cast(
+                bigframes.dataframe.DataFrame,
+                model.predict(
+                    self._make_text_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                ),
+            )

         return self._df[
             results["ml_generate_text_llm_result"].str.lower().str.contains("true")
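The dispatch above hinges on a single dtype check: blob columns are left uncast, and their presence flips predict() into multi-model mode. A minimal standalone sketch of that idea; `_has_blob_column` is a hypothetical helper, not something this commit adds:

import bigframes.dtypes as dtypes

def _has_blob_column(df, columns) -> bool:
    # Any requested column holding object references (blobs) selects the
    # multi-model branch; pure-text inputs stay on the text-prompt path.
    return any(df[col].dtype == dtypes.OBJ_REF_DTYPE for col in columns)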
@@ -480,7 +501,13 @@ def map(
         self._confirm_operation(len(self._df))

         df: bigframes.dataframe.DataFrame = self._df[columns].copy()
+        has_blob_column = False
         for column in columns:
+            if df[column].dtype == dtypes.OBJ_REF_DTYPE:
+                # Don't cast blob columns to string
+                has_blob_column = True
+                continue
+
             if df[column].dtype != dtypes.STRING_DTYPE:
                 df[column] = df[column].astype(dtypes.STRING_DTYPE)

@@ -489,14 +516,29 @@ def map(
             "Based on the provided contenxt, answer the following instruction:"
         )

-        results = typing.cast(
-            bigframes.series.Series,
-            model.predict(
-                self._make_prompt(df, columns, user_instruction, output_instruction),
-                temperature=0.0,
-                ground_with_google_search=ground_with_google_search,
-            )["ml_generate_text_llm_result"],
-        )
+        if has_blob_column:
+            results = typing.cast(
+                bigframes.series.Series,
+                model.predict(
+                    df,
+                    prompt=self._make_multimodel_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                )["ml_generate_text_llm_result"],
+            )
+        else:
+            results = typing.cast(
+                bigframes.series.Series,
+                model.predict(
+                    self._make_text_prompt(
+                        df, columns, user_instruction, output_instruction
+                    ),
+                    temperature=0.0,
+                    ground_with_google_search=ground_with_google_search,
+                )["ml_generate_text_llm_result"],
+            )

         from bigframes.core.reshape.api import concat

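map mirrors the same branching as filter. A hedged, self-contained sketch of a blob-backed sem_map call; the `output_column` keyword naming the result column is an assumption about map's signature, which this diff does not show:

import bigframes.pandas as bpd
from bigframes.ml import llm

# Assumed experimental flag and blob-column helper, as in the first sketch.
bpd.options.experiments.semantic_operators = True
df = bpd.from_glob_path("gs://my-test-bucket/images/*", name="image")
model = llm.GeminiTextGenerator()

# Hypothetical multi-model map: "{image}" resolves to the blob column.
described = df.semantics.map(
    "in one word, what animal does the {image} show",
    output_column="animal",
    model=model,
)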
@@ -1060,8 +1102,19 @@ def _attach_embedding(dataframe, source_column: str, embedding_column: str, model)
         result_df[embedding_column] = embeddings
         return result_df

-    def _make_prompt(
-        self, prompt_df, columns, user_instruction: str, output_instruction: str
+    @staticmethod
+    def _make_multimodel_prompt(
+        prompt_df, columns, user_instruction: str, output_instruction: str
+    ):
+        prompt = [f"{output_instruction}\n{user_instruction}\nContext: "]
+        for col in columns:
+            prompt.extend([f"{col} is ", prompt_df[col]])
+
+        return prompt
+
+    @staticmethod
+    def _make_text_prompt(
+        prompt_df, columns, user_instruction: str, output_instruction: str
     ):
         prompt_df["prompt"] = f"{output_instruction}\n{user_instruction}\nContext: "

@@ -1071,7 +1124,8 @@ def _make_prompt(

         return prompt_df["prompt"]

-    def _parse_columns(self, instruction: str) -> List[str]:
+    @staticmethod
+    def _parse_columns(instruction: str) -> List[str]:
         """Extracts column names enclosed in curly braces from the user instruction.
         For example, _parse_columns("{city} is in {continent}") == ["city", "continent"]
         """
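The docstring's example suggests simple brace extraction. One minimal way to implement it, shown as a sketch that covers only the documented case (the actual implementation may treat escaped braces differently):

import re
from typing import List

def _parse_columns(instruction: str) -> List[str]:
    # Non-greedy match of the text between each pair of curly braces.
    return re.findall(r"\{(.*?)\}", instruction)

assert _parse_columns("{city} is in {continent}") == ["city", "continent"]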