Skip to content

Commit 6199023

Browse files
authored
feat: add GeminiTextGenerator.predict structured output (#1653)
* feat: add GeminiTextGenerator.predict structured output * test * fix tests
1 parent cd7fbde commit 6199023

File tree

9 files changed

+339
-244
lines changed

9 files changed

+339
-244
lines changed

bigframes/ml/base.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"""
2323

2424
import abc
25-
from typing import Callable, cast, Mapping, Optional, TypeVar, Union
25+
from typing import cast, Optional, TypeVar, Union
2626
import warnings
2727

2828
import bigframes_vendored.sklearn.base
@@ -244,18 +244,12 @@ def fit(
244244

245245

246246
class RetriableRemotePredictor(BaseEstimator):
247-
@property
248-
@abc.abstractmethod
249-
def _predict_func(self) -> Callable[[bpd.DataFrame, Mapping], bpd.DataFrame]:
250-
pass
251-
252-
@property
253-
@abc.abstractmethod
254-
def _status_col(self) -> str:
255-
pass
256-
257247
def _predict_and_retry(
258-
self, X: bpd.DataFrame, options: Mapping, max_retries: int
248+
self,
249+
bqml_model_predict_tvf: core.BqmlModel.TvfDef,
250+
X: bpd.DataFrame,
251+
options: dict,
252+
max_retries: int,
259253
) -> bpd.DataFrame:
260254
assert self._bqml_model is not None
261255

@@ -269,9 +263,9 @@ def _predict_and_retry(
269263
warnings.warn(msg, category=RuntimeWarning)
270264
break
271265

272-
df = self._predict_func(df_fail, options)
266+
df = bqml_model_predict_tvf.tvf(self._bqml_model, df_fail, options)
273267

274-
success = df[self._status_col].str.len() == 0
268+
success = df[bqml_model_predict_tvf.status_col].str.len() == 0
275269
df_succ = df[success]
276270
df_fail = df[~success]
277271

bigframes/ml/core.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import dataclasses
1920
import datetime
2021
from typing import Callable, cast, Iterable, Mapping, Optional, Union
2122
import uuid
@@ -44,6 +45,11 @@ class BqmlModel(BaseBqml):
4445
BigQuery DataFrames ML.
4546
"""
4647

48+
@dataclasses.dataclass
49+
class TvfDef:
50+
tvf: Callable[[BqmlModel, bpd.DataFrame, dict], bpd.DataFrame]
51+
status_col: str
52+
4753
def __init__(self, session: bigframes.Session, model: bigquery.Model):
4854
self._session = session
4955
self._model = model
@@ -159,8 +165,9 @@ def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
159165
def generate_text(
160166
self,
161167
input_data: bpd.DataFrame,
162-
options: Mapping[str, int | float],
168+
options: dict[str, Union[int, float, bool]],
163169
) -> bpd.DataFrame:
170+
options["flatten_json_output"] = True
164171
return self._apply_ml_tvf(
165172
input_data,
166173
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text(
@@ -169,11 +176,14 @@ def generate_text(
169176
),
170177
)
171178

179+
generate_text_tvf = TvfDef(generate_text, "ml_generate_text_status")
180+
172181
def generate_embedding(
173182
self,
174183
input_data: bpd.DataFrame,
175-
options: Mapping[str, int | float],
184+
options: dict[str, Union[int, float, bool]],
176185
) -> bpd.DataFrame:
186+
options["flatten_json_output"] = True
177187
return self._apply_ml_tvf(
178188
input_data,
179189
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_embedding(
@@ -182,6 +192,23 @@ def generate_embedding(
182192
),
183193
)
184194

195+
generate_embedding_tvf = TvfDef(generate_embedding, "ml_generate_embedding_status")
196+
197+
def generate_table(
198+
self,
199+
input_data: bpd.DataFrame,
200+
options: dict[str, Union[int, float, bool, Mapping]],
201+
) -> bpd.DataFrame:
202+
return self._apply_ml_tvf(
203+
input_data,
204+
lambda source_sql: self._model_manipulation_sql_generator.ai_generate_table(
205+
source_sql=source_sql,
206+
struct_options=options,
207+
),
208+
)
209+
210+
generate_table_tvf = TvfDef(generate_table, "status")
211+
185212
def detect_anomalies(
186213
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
187214
) -> bpd.DataFrame:

bigframes/ml/globals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
_BASE_SQL_GENERATOR = sql.BaseSqlGenerator()
2020
_BQML_MODEL_FACTORY = core.BqmlModelFactory()
2121

22-
_SUPPORTED_DTYPES = (
22+
_REMOTE_MODEL_SUPPORTED_DTYPES = (
2323
"bool",
2424
"string",
2525
"int64",

bigframes/ml/imported.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def __init__(
216216
self,
217217
model_path: str,
218218
*,
219-
input: Mapping[str, str] = {},
220-
output: Mapping[str, str] = {},
219+
input: Optional[Mapping[str, str]] = None,
220+
output: Optional[Mapping[str, str]] = None,
221221
session: Optional[bigframes.session.Session] = None,
222222
):
223223
self.session = session or bpd.get_global_session()
@@ -234,20 +234,23 @@ def _create_bqml_model(self):
234234
return self._bqml_model_factory.create_imported_model(
235235
session=self.session, options=options
236236
)
237-
else:
238-
for io in (self.input, self.output):
239-
for v in io.values():
240-
if v not in globals._SUPPORTED_DTYPES:
241-
raise ValueError(
242-
f"field_type {v} is not supported. We only support {', '.join(globals._SUPPORTED_DTYPES)}."
243-
)
244-
245-
return self._bqml_model_factory.create_xgboost_imported_model(
246-
session=self.session,
247-
input=self.input,
248-
output=self.output,
249-
options=options,
250-
)
237+
if not self.input or not self.output:
238+
raise ValueError("input and output must both or neigher be set.")
239+
self.input = {
240+
k: utils.standardize_type(v, globals._REMOTE_MODEL_SUPPORTED_DTYPES)
241+
for k, v in self.input.items()
242+
}
243+
self.output = {
244+
k: utils.standardize_type(v, globals._REMOTE_MODEL_SUPPORTED_DTYPES)
245+
for k, v in self.output.items()
246+
}
247+
248+
return self._bqml_model_factory.create_xgboost_imported_model(
249+
session=self.session,
250+
input=self.input,
251+
output=self.output,
252+
options=options,
253+
)
251254

252255
@classmethod
253256
def _from_bq(

bigframes/ml/llm.py

Lines changed: 44 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Callable, cast, Iterable, Literal, Mapping, Optional, Union
19+
from typing import cast, Iterable, Literal, Mapping, Optional, Union
2020
import warnings
2121

2222
import bigframes_vendored.constants as constants
@@ -92,10 +92,6 @@
9292
_CLAUDE_3_OPUS_ENDPOINT,
9393
)
9494

95-
96-
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
97-
_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"
98-
9995
_MODEL_NOT_SUPPORTED_WARNING = (
10096
"Model name '{model_name}' is not supported. "
10197
"We are currently aware of the following models: {known_models}. "
@@ -193,18 +189,6 @@ def _from_bq(
193189
model._bqml_model = core.BqmlModel(session, bq_model)
194190
return model
195191

196-
@property
197-
def _predict_func(
198-
self,
199-
) -> Callable[
200-
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
201-
]:
202-
return self._bqml_model.generate_embedding
203-
204-
@property
205-
def _status_col(self) -> str:
206-
return _ML_GENERATE_EMBEDDING_STATUS
207-
208192
def predict(
209193
self, X: utils.ArrayType, *, max_retries: int = 0
210194
) -> bigframes.dataframe.DataFrame:
@@ -233,11 +217,14 @@ def predict(
233217
col_label = cast(blocks.Label, X.columns[0])
234218
X = X.rename(columns={col_label: "content"})
235219

236-
options = {
237-
"flatten_json_output": True,
238-
}
220+
options: dict = {}
239221

240-
return self._predict_and_retry(X, options=options, max_retries=max_retries)
222+
return self._predict_and_retry(
223+
core.BqmlModel.generate_embedding_tvf,
224+
X,
225+
options=options,
226+
max_retries=max_retries,
227+
)
241228

242229
def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
243230
"""Save the model to BigQuery.
@@ -339,18 +326,6 @@ def _from_bq(
339326
model._bqml_model = core.BqmlModel(session, bq_model)
340327
return model
341328

342-
@property
343-
def _predict_func(
344-
self,
345-
) -> Callable[
346-
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
347-
]:
348-
return self._bqml_model.generate_embedding
349-
350-
@property
351-
def _status_col(self) -> str:
352-
return _ML_GENERATE_EMBEDDING_STATUS
353-
354329
def predict(
355330
self, X: utils.ArrayType, *, max_retries: int = 0
356331
) -> bigframes.dataframe.DataFrame:
@@ -384,11 +359,14 @@ def predict(
384359
if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
385360
X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)
386361

387-
options = {
388-
"flatten_json_output": True,
389-
}
362+
options: dict = {}
390363

391-
return self._predict_and_retry(X, options=options, max_retries=max_retries)
364+
return self._predict_and_retry(
365+
core.BqmlModel.generate_embedding_tvf,
366+
X,
367+
options=options,
368+
max_retries=max_retries,
369+
)
392370

393371
def to_gbq(
394372
self, model_name: str, replace: bool = False
@@ -533,18 +511,6 @@ def _bqml_options(self) -> dict:
533511
}
534512
return options
535513

536-
@property
537-
def _predict_func(
538-
self,
539-
) -> Callable[
540-
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
541-
]:
542-
return self._bqml_model.generate_text
543-
544-
@property
545-
def _status_col(self) -> str:
546-
return _ML_GENERATE_TEXT_STATUS
547-
548514
def fit(
549515
self,
550516
X: utils.ArrayType,
@@ -596,6 +562,7 @@ def predict(
596562
ground_with_google_search: bool = False,
597563
max_retries: int = 0,
598564
prompt: Optional[Iterable[Union[str, bigframes.series.Series]]] = None,
565+
output_schema: Optional[Mapping[str, str]] = None,
599566
) -> bigframes.dataframe.DataFrame:
600567
"""Predict the result from input DataFrame.
601568
@@ -645,6 +612,9 @@ def predict(
645612
Construct a prompt struct column for prediction based on the input. The input must be an Iterable that can take string literals,
646613
such as "summarize", string column(s) of X, such as X["str_col"], or blob column(s) of X, such as X["blob_col"].
647614
It creates a struct column of the items of the iterable, and use the concatenated result as the input prompt. No-op if set to None.
615+
output_schema (Mapping[str, str] or None, default None):
616+
The schema used to generate structured output as a bigframes DataFrame. The schema is a string key-value pair of <column_name>:<type>.
617+
Supported types are int64, float64, bool and string. If None, output text result.
648618
Returns:
649619
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
650620
"""
@@ -707,16 +677,31 @@ def predict(
707677
col_label = cast(blocks.Label, X.columns[0])
708678
X = X.rename(columns={col_label: "prompt"})
709679

710-
options = {
680+
options: dict = {
711681
"temperature": temperature,
712682
"max_output_tokens": max_output_tokens,
713-
"top_k": top_k,
683+
# "top_k": top_k, # TODO(garrettwu): the option is deprecated in Gemini 1.5 forward.
714684
"top_p": top_p,
715-
"flatten_json_output": True,
716685
"ground_with_google_search": ground_with_google_search,
717686
}
687+
if output_schema:
688+
output_schema = {
689+
k: utils.standardize_type(v) for k, v in output_schema.items()
690+
}
691+
options["output_schema"] = output_schema
692+
return self._predict_and_retry(
693+
core.BqmlModel.generate_table_tvf,
694+
X,
695+
options=options,
696+
max_retries=max_retries,
697+
)
718698

719-
return self._predict_and_retry(X, options=options, max_retries=max_retries)
699+
return self._predict_and_retry(
700+
core.BqmlModel.generate_text_tvf,
701+
X,
702+
options=options,
703+
max_retries=max_retries,
704+
)
720705

721706
def score(
722707
self,
@@ -916,18 +901,6 @@ def _bqml_options(self) -> dict:
916901
}
917902
return options
918903

919-
@property
920-
def _predict_func(
921-
self,
922-
) -> Callable[
923-
[bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
924-
]:
925-
return self._bqml_model.generate_text
926-
927-
@property
928-
def _status_col(self) -> str:
929-
return _ML_GENERATE_TEXT_STATUS
930-
931904
def predict(
932905
self,
933906
X: utils.ArrayType,
@@ -1000,10 +973,14 @@ def predict(
1000973
"max_output_tokens": max_output_tokens,
1001974
"top_k": top_k,
1002975
"top_p": top_p,
1003-
"flatten_json_output": True,
1004976
}
1005977

1006-
return self._predict_and_retry(X, options=options, max_retries=max_retries)
978+
return self._predict_and_retry(
979+
core.BqmlModel.generate_text_tvf,
980+
X,
981+
options=options,
982+
max_retries=max_retries,
983+
)
1007984

1008985
def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
1009986
"""Save the model to BigQuery.

0 commit comments

Comments
 (0)