@@ -16,7 +16,7 @@
 
 from __future__ import annotations
 
-from typing import Callable, cast, Iterable, Literal, Mapping, Optional, Union
+from typing import cast, Iterable, Literal, Mapping, Optional, Union
 import warnings
 
 import bigframes_vendored.constants as constants
@@ -92,10 +92,6 @@
     _CLAUDE_3_OPUS_ENDPOINT,
 )
 
-
-_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
-_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"
-
 _MODEL_NOT_SUPPORTED_WARNING = (
     "Model name '{model_name}' is not supported. "
     "We are currently aware of the following models: {known_models}. "
@@ -193,18 +189,6 @@ def _from_bq(
         model._bqml_model = core.BqmlModel(session, bq_model)
         return model
 
-    @property
-    def _predict_func(
-        self,
-    ) -> Callable[
-        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
-    ]:
-        return self._bqml_model.generate_embedding
-
-    @property
-    def _status_col(self) -> str:
-        return _ML_GENERATE_EMBEDDING_STATUS
-
     def predict(
         self, X: utils.ArrayType, *, max_retries: int = 0
     ) -> bigframes.dataframe.DataFrame:
@@ -233,11 +217,14 @@ def predict(
         col_label = cast(blocks.Label, X.columns[0])
         X = X.rename(columns={col_label: "content"})
 
-        options = {
-            "flatten_json_output": True,
-        }
+        options: dict = {}
 
-        return self._predict_and_retry(X, options=options, max_retries=max_retries)
+        return self._predict_and_retry(
+            core.BqmlModel.generate_embedding_tvf,
+            X,
+            options=options,
+            max_retries=max_retries,
+        )
 
     def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
         """Save the model to BigQuery.
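For context, the caller-facing API of this path is unchanged by the refactor. A minimal usage sketch, assuming an illustrative model name and input data (neither appears in this diff):

    import bigframes.pandas as bpd
    from bigframes.ml.llm import TextEmbeddingGenerator

    # predict() renames the single input column to "content" before calling the TVF.
    df = bpd.DataFrame({"content": ["hello world", "bigframes ml"]})
    model = TextEmbeddingGenerator(model_name="text-embedding-005")  # model name assumed
    embeddings = model.predict(df, max_retries=2)  # re-run failed rows up to 2 times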
@@ -339,18 +326,6 @@ def _from_bq(
         model._bqml_model = core.BqmlModel(session, bq_model)
         return model
 
-    @property
-    def _predict_func(
-        self,
-    ) -> Callable[
-        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
-    ]:
-        return self._bqml_model.generate_embedding
-
-    @property
-    def _status_col(self) -> str:
-        return _ML_GENERATE_EMBEDDING_STATUS
-
     def predict(
         self, X: utils.ArrayType, *, max_retries: int = 0
     ) -> bigframes.dataframe.DataFrame:
@@ -384,11 +359,14 @@ def predict(
         if X["content"].dtype == dtypes.OBJ_REF_DTYPE:
             X["content"] = X["content"].blob._get_runtime("R", with_metadata=True)
 
-        options = {
-            "flatten_json_output": True,
-        }
+        options: dict = {}
 
-        return self._predict_and_retry(X, options=options, max_retries=max_retries)
+        return self._predict_and_retry(
+            core.BqmlModel.generate_embedding_tvf,
+            X,
+            options=options,
+            max_retries=max_retries,
+        )
 
     def to_gbq(
         self, model_name: str, replace: bool = False
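Both embedding hunks swap the removed `_predict_func` property for an explicit argument: the unbound TVF method is passed in and bound to the wrapped model inside the retry helper. A self-contained sketch of that dispatch pattern, with simplified stand-ins for `core.BqmlModel` and `_predict_and_retry` (not the actual bigframes implementation):

    from typing import Callable, Mapping

    class Engine:
        """Simplified stand-in for core.BqmlModel; each TVF is a plain method."""

        def generate_text_tvf(self, rows: list, options: Mapping) -> list:
            return [f"text:{r}" for r in rows]

        def generate_embedding_tvf(self, rows: list, options: Mapping) -> list:
            return [f"embed:{r}" for r in rows]

    class Model:
        def __init__(self) -> None:
            self._engine = Engine()

        def _predict_and_retry(
            self,
            tvf: Callable[[Engine, list, Mapping], list],
            rows: list,
            *,
            options: Mapping,
            max_retries: int = 0,
        ) -> list:
            # The caller passes the unbound method (e.g. Engine.generate_embedding_tvf);
            # it is bound here, so one retry helper can drive any of the TVFs.
            return tvf(self._engine, rows, options)

    model = Model()
    print(model._predict_and_retry(Engine.generate_embedding_tvf, ["hi"], options={}))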
@@ -533,18 +511,6 @@ def _bqml_options(self) -> dict:
         }
         return options
 
-    @property
-    def _predict_func(
-        self,
-    ) -> Callable[
-        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
-    ]:
-        return self._bqml_model.generate_text
-
-    @property
-    def _status_col(self) -> str:
-        return _ML_GENERATE_TEXT_STATUS
-
     def fit(
         self,
         X: utils.ArrayType,
@@ -596,6 +562,7 @@ def predict(
         ground_with_google_search: bool = False,
         max_retries: int = 0,
         prompt: Optional[Iterable[Union[str, bigframes.series.Series]]] = None,
+        output_schema: Optional[Mapping[str, str]] = None,
     ) -> bigframes.dataframe.DataFrame:
         """Predict the result from input DataFrame.
 
@@ -645,6 +612,9 @@ def predict(
                 Construct a prompt struct column for prediction based on the input. The input must be an Iterable that can take string literals,
                 such as "summarize", string column(s) of X, such as X["str_col"], or blob column(s) of X, such as X["blob_col"].
                 It creates a struct column of the items of the iterable, and use the concatenated result as the input prompt. No-op if set to None.
+            output_schema (Mapping[str, str] or None, default None):
+                The schema used to generate structured output as a bigframes DataFrame, given as <column_name>: <type> string pairs.
+                Supported types are int64, float64, bool and string. If None, the result is returned as plain text.
         Returns:
             bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
         """
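A hedged usage sketch of the new parameter, assuming this predict belongs to the module's Gemini text generator (the class name is not visible in this hunk) and using invented prompt text:

    import bigframes.pandas as bpd
    from bigframes.ml.llm import GeminiTextGenerator

    df = bpd.DataFrame({"prompt": ["Name one city and its population."]})
    model = GeminiTextGenerator()  # default model; constructor arguments omitted
    result = model.predict(
        df,
        output_schema={"city": "string", "population": "int64"},
    )
    # result carries one typed column per schema entry instead of a single text column.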
@@ -707,16 +677,31 @@ def predict(
         col_label = cast(blocks.Label, X.columns[0])
         X = X.rename(columns={col_label: "prompt"})
 
-        options = {
+        options: dict = {
             "temperature": temperature,
             "max_output_tokens": max_output_tokens,
-            "top_k": top_k,
+            # "top_k": top_k,  # TODO(garrettwu): the option is deprecated from Gemini 1.5 onward.
             "top_p": top_p,
-            "flatten_json_output": True,
             "ground_with_google_search": ground_with_google_search,
         }
+        if output_schema:
+            output_schema = {
+                k: utils.standardize_type(v) for k, v in output_schema.items()
+            }
+            options["output_schema"] = output_schema
+            return self._predict_and_retry(
+                core.BqmlModel.generate_table_tvf,
+                X,
+                options=options,
+                max_retries=max_retries,
+            )
 
-        return self._predict_and_retry(X, options=options, max_retries=max_retries)
+        return self._predict_and_retry(
+            core.BqmlModel.generate_text_tvf,
+            X,
+            options=options,
+            max_retries=max_retries,
+        )
 
     def score(
         self,
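The comprehension above funnels every schema type through `utils.standardize_type` before it reaches the query. A minimal sketch of what such a normalizer might do, assuming simple alias mapping (the real helper in bigframes.ml.utils may differ):

    def standardize_type(type_name: str) -> str:
        """Map common aliases onto the canonical names accepted by output_schema."""
        aliases = {"str": "string", "int": "int64", "float": "float64", "boolean": "bool"}
        normalized = type_name.strip().lower()
        return aliases.get(normalized, normalized)

    assert standardize_type("STR") == "string"
    assert standardize_type("float") == "float64"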
@@ -916,18 +901,6 @@ def _bqml_options(self) -> dict:
         }
         return options
 
-    @property
-    def _predict_func(
-        self,
-    ) -> Callable[
-        [bigframes.dataframe.DataFrame, Mapping], bigframes.dataframe.DataFrame
-    ]:
-        return self._bqml_model.generate_text
-
-    @property
-    def _status_col(self) -> str:
-        return _ML_GENERATE_TEXT_STATUS
-
     def predict(
         self,
         X: utils.ArrayType,
@@ -1000,10 +973,14 @@ def predict(
             "max_output_tokens": max_output_tokens,
             "top_k": top_k,
             "top_p": top_p,
-            "flatten_json_output": True,
         }
 
-        return self._predict_and_retry(X, options=options, max_retries=max_retries)
+        return self._predict_and_retry(
+            core.BqmlModel.generate_text_tvf,
+            X,
+            options=options,
+            max_retries=max_retries,
+        )
 
     def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
         """Save the model to BigQuery.