Skip to content

Commit c188f49

Browse files
authored
chore: add experimental Multimodal support in Gemini (#1368)
* chore: add experimental Multimodal support in Gemini * fix * warning
1 parent aec3fe7 commit c188f49

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

bigframes/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484

8585
import bigframes.session
8686

87-
SingleItemValue = Union[bigframes.series.Series, int, float, Callable]
87+
SingleItemValue = Union[bigframes.series.Series, int, float, str, Callable]
8888

8989
LevelType = typing.Hashable
9090
LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
@@ -1953,7 +1953,7 @@ def _assign_single_item_listlike(self, k: str, v: Sequence) -> DataFrame:
19531953
result_block = result_block.drop_columns([src_col])
19541954
return DataFrame(result_block)
19551955

1956-
def _assign_scalar(self, label: str, value: Union[int, float]) -> DataFrame:
1956+
def _assign_scalar(self, label: str, value: Union[int, float, str]) -> DataFrame:
19571957
col_ids = self._block.cols_matching_label(label)
19581958

19591959
block, constant_col_id = self._block.create_constant(value, label)

bigframes/ml/llm.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Callable, cast, Literal, Mapping, Optional
19+
from typing import Callable, cast, Iterable, Literal, Mapping, Optional, Union
2020
import warnings
2121

2222
import bigframes_vendored.constants as constants
2323
from google.cloud import bigquery
2424
import typing_extensions
2525

26-
from bigframes import clients, exceptions
26+
from bigframes import clients, dtypes, exceptions
27+
import bigframes.bigquery as bbq
2728
from bigframes.core import blocks, global_session, log_adapter
2829
import bigframes.dataframe
2930
from bigframes.ml import base, core, globals, utils
31+
import bigframes.series
3032

3133
_BQML_PARAMS_MAPPING = {
3234
"max_iterations": "maxIterations",
@@ -83,6 +85,13 @@
8385
_GEMINI_1P5_PRO_002_ENDPOINT,
8486
_GEMINI_1P5_FLASH_002_ENDPOINT,
8587
)
88+
_GEMINI_MULTIMODAL_ENDPOINTS = (
89+
_GEMINI_1P5_PRO_001_ENDPOINT,
90+
_GEMINI_1P5_PRO_002_ENDPOINT,
91+
_GEMINI_1P5_FLASH_001_ENDPOINT,
92+
_GEMINI_1P5_FLASH_002_ENDPOINT,
93+
_GEMINI_2_FLASH_EXP_ENDPOINT,
94+
)
8695

8796
_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
8897
_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku"
@@ -925,12 +934,13 @@ def predict(
925934
top_p: float = 1.0,
926935
ground_with_google_search: bool = False,
927936
max_retries: int = 0,
937+
prompt: Optional[Iterable[Union[str, bigframes.series.Series]]] = None,
928938
) -> bigframes.dataframe.DataFrame:
929939
"""Predict the result from input DataFrame.
930940
931941
Args:
932942
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
933-
Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "prompt" column for prediction.
943+
Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, the "prompt" column, or the one created by the "prompt" parameter, is used for prediction.
934944
Prompts can include preamble, questions, suggestions, instructions, or examples.
935945
936946
temperature (float, default 0.9):
@@ -966,6 +976,14 @@ def predict(
966976
max_retries (int, default 0):
967977
Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry.
968978
Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result.
979+
980+
prompt (Iterable of str or bigframes.series.Series, or None, default None):
981+
.. note::
982+
BigFrames Blob is still experimental. It may not work and is subject to change in the future.
983+
984+
Construct a prompt struct column for prediction based on the input. The input must be an Iterable that can take string literals,
985+
such as "summarize", string column(s) of X, such as X["str_col"], or blob column(s) of X, such as X["blob_col"].
986+
It creates a struct column from the items of the iterable and uses the concatenated result as the input prompt. No-op if set to None.
969987
Returns:
970988
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
971989
"""
@@ -990,7 +1008,38 @@ def predict(
9901008
f"max_retries must be larger than or equal to 0, but is {max_retries}."
9911009
)
9921010

993-
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)
1011+
session = self._bqml_model.session
1012+
(X,) = utils.batch_convert_to_dataframe(X, session=session)
1013+
1014+
if prompt:
1015+
if not bigframes.options.experiments.blob:
1016+
raise NotImplementedError()
1017+
1018+
if self.model_name not in _GEMINI_MULTIMODAL_ENDPOINTS:
1019+
raise NotImplementedError(
1020+
f"GeminiTextGenerator only supports model_name {', '.join(_GEMINI_MULTIMODAL_ENDPOINTS)} for Multimodal prompt."
1021+
)
1022+
1023+
df_prompt = X[[X.columns[0]]].rename(
1024+
columns={X.columns[0]: "bigframes_placeholder_col"}
1025+
)
1026+
for i, item in enumerate(prompt):
1027+
# must be distinct str column labels to construct a struct
1028+
if isinstance(item, str):
1029+
label = f"input_{i}"
1030+
else: # Series
1031+
label = f"input_{i}_{item.name}"
1032+
1033+
# TODO(garrettwu): remove transform to ObjRefRuntime when BQML supports ObjRef as input
1034+
if (
1035+
isinstance(item, bigframes.series.Series)
1036+
and item.dtype == dtypes.OBJ_REF_DTYPE
1037+
):
1038+
item = item.blob._get_runtime("R", with_metadata=True)
1039+
1040+
df_prompt[label] = item
1041+
df_prompt = df_prompt.drop(columns="bigframes_placeholder_col")
1042+
X["prompt"] = bbq.struct(df_prompt)
9941043

9951044
if len(X.columns) == 1:
9961045
# BQML identified the column by name

0 commit comments

Comments
 (0)