1414
1515from __future__ import annotations
1616
17- import typing
18- from typing import Mapping , Optional , TYPE_CHECKING , Union
17+ from typing import Mapping , Optional , Union
1918
2019import bigframes .core .log_adapter as log_adapter
2120import bigframes .core .sql .ml
2221import bigframes .dataframe as dataframe
23-
24- if TYPE_CHECKING :
25- import bigframes .ml .base
26- import bigframes .session
22+ import bigframes .ml .base
23+ import bigframes .session
2724
2825
2926# Helper to convert DataFrame to SQL string
@@ -35,12 +32,37 @@ def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
3532 return sql
3633
3734
def _get_model_name_and_session(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    # DataFrame/SQL arguments scanned for a session when ``model`` is a name.
    *dataframes: Optional[Union[dataframe.DataFrame, str]],
) -> tuple[str, bigframes.session.Session]:
    """Resolve ``model`` to a BigQuery model name plus a usable session.

    A fitted estimator supplies both directly; a string model name borrows the
    session from the first DataFrame argument, falling back to the global one.
    """
    import bigframes.pandas as bpd

    if not isinstance(model, str):
        bqml_model = model._bqml_model
        if bqml_model is None:
            raise ValueError("Model must be fitted to be used in ML operations.")
        return bqml_model.model_name, bqml_model.session

    session = next(
        (df._session for df in dataframes if isinstance(df, dataframe.DataFrame)),
        None,
    )
    if session is None:
        session = bpd.get_global_session()
    return model, session
56+
57+
3858@log_adapter .method_logger (custom_base_name = "bigquery_ml" )
3959def create_model (
4060 model_name : str ,
4161 * ,
4262 replace : bool = False ,
4363 if_not_exists : bool = False ,
64+ # TODO(tswast): Also support bigframes.ml transformer classes and/or
65+ # bigframes.pandas functions?
4466 transform : Optional [list [str ]] = None ,
4567 input_schema : Optional [Mapping [str , str ]] = None ,
4668 output_schema : Optional [Mapping [str , str ]] = None ,
@@ -53,6 +75,10 @@ def create_model(
5375 """
5476 Creates a BigQuery ML model.
5577
78+ See the `BigQuery ML CREATE MODEL DDL syntax
79+ <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create>`_
80+ for additional reference.
81+
5682 Args:
5783 model_name (str):
5884 The name of the model in BigQuery.
@@ -61,7 +87,8 @@ def create_model(
6187 if_not_exists (bool, default False):
6288 Whether to ignore the error if the model already exists.
6389 transform (list[str], optional):
64- The TRANSFORM clause, which specifies the preprocessing steps to apply to the input data.
90+ A list of SQL transformations for the TRANSFORM clause, which
91+ specifies the preprocessing steps to apply to the input data.
6592 input_schema (Mapping[str, str], optional):
6693 The INPUT clause, which specifies the schema of the input data.
6794 output_schema (Mapping[str, str], optional):
@@ -70,16 +97,16 @@ def create_model(
7097 The connection to use for the model.
7198 options (Mapping[str, Union[str, int, float, bool, list]], optional):
7299 The OPTIONS clause, which specifies the model options.
73- training_data (Union[dataframe .DataFrame, str], optional):
100+ training_data (Union[bigframes.pandas .DataFrame, str], optional):
74101 The query or DataFrame to use for training the model.
75- custom_holiday (Union[dataframe .DataFrame, str], optional):
102+ custom_holiday (Union[bigframes.pandas .DataFrame, str], optional):
76103 The query or DataFrame to use for custom holiday data.
77104 session (bigframes.session.Session, optional):
78- The BigFrames session to use. If not provided, the default session is used.
105+ The session to use. If not provided, the default session is used.
79106
80107 Returns:
81108 bigframes.ml.base.BaseEstimator:
82- The created BigFrames model.
109+ The created BigQuery ML model.
83110 """
84111 import bigframes .pandas as bpd
85112
@@ -117,3 +144,196 @@ def create_model(
117144 session ._start_query_ml_ddl (sql )
118145
119146 return session .read_gbq_model (model_name )
147+
148+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def evaluate(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    input_: Optional[Union[dataframe.DataFrame, str]] = None,
    *,
    perform_aggregation: Optional[bool] = None,
    horizon: Optional[int] = None,
    confidence_level: Optional[float] = None,
) -> dataframe.DataFrame:
    """
    Evaluates a BigQuery ML model via ``ML.EVALUATE``.

    See the `BigQuery ML EVALUATE function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to evaluate.
        input_ (Union[bigframes.pandas.DataFrame, str], optional):
            DataFrame or query providing evaluation rows. When omitted, the
            evaluation data captured at training time is used.
        perform_aggregation (bool, optional):
            A BOOL value controlling the granularity of forecasting accuracy:
            TRUE aggregates per time series, FALSE reports per timestamp.
            The default value is TRUE.
        horizon (int, optional):
            An INT64 value giving how many forecasted time points the metrics
            are computed over. Defaults to the horizon from the CREATE MODEL
            statement, or 1000 if unspecified; applied per series when
            evaluating multiple time series at once.
        confidence_level (float, optional):
            A FLOAT64 value for the percentage of future values expected to
            fall inside the prediction interval. Defaults to 0.95; valid
            range is ``[0, 1)``.

    Returns:
        bigframes.pandas.DataFrame:
            The evaluation results.
    """
    model_name, session = _get_model_name_and_session(model, input_)

    if input_ is None:
        table_sql = None
    else:
        table_sql = _to_sql(input_)

    query = bigframes.core.sql.ml.evaluate(
        model_name=model_name,
        table=table_sql,
        perform_aggregation=perform_aggregation,
        horizon=horizon,
        confidence_level=confidence_level,
    )
    return session.read_gbq(query)
204+
205+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def predict(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    input_: Union[dataframe.DataFrame, str],
    *,
    threshold: Optional[float] = None,
    keep_original_columns: Optional[bool] = None,
    trial_id: Optional[int] = None,
) -> dataframe.DataFrame:
    """
    Runs prediction on a BigQuery ML model via ``ML.PREDICT``.

    See the `BigQuery ML PREDICT function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to use for prediction.
        input_ (Union[bigframes.pandas.DataFrame, str]):
            DataFrame or query providing the rows to predict on.
        threshold (float, optional):
            Classification threshold for classification models.
        keep_original_columns (bool, optional):
            Whether the input columns are carried through to the output.
        trial_id (int, optional):
            An INT64 value selecting a specific hyperparameter tuning trial;
            the optimal trial is used by default. Only meaningful for models
            created with hyperparameter tuning.

    Returns:
        bigframes.pandas.DataFrame:
            The prediction results.
    """
    model_name, session = _get_model_name_and_session(model, input_)

    query = bigframes.core.sql.ml.predict(
        model_name=model_name,
        table=_to_sql(input_),
        threshold=threshold,
        keep_original_columns=keep_original_columns,
        trial_id=trial_id,
    )
    return session.read_gbq(query)
253+
254+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def explain_predict(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    input_: Union[dataframe.DataFrame, str],
    *,
    top_k_features: Optional[int] = None,
    threshold: Optional[float] = None,
    integrated_gradients_num_steps: Optional[int] = None,
    approx_feature_contrib: Optional[bool] = None,
) -> dataframe.DataFrame:
    """
    Runs explainable prediction on a BigQuery ML model via ``ML.EXPLAIN_PREDICT``.

    See the `BigQuery ML EXPLAIN_PREDICT function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to use for prediction.
        input_ (Union[bigframes.pandas.DataFrame, str]):
            DataFrame or query providing the rows to predict on.
        top_k_features (int, optional):
            How many top features to include per explanation.
        threshold (float, optional):
            Classification threshold for binary classification models.
        integrated_gradients_num_steps (int, optional):
            An INT64 value: the number of steps sampled between the example
            being explained and its baseline when approximating the integral
            in integrated gradients attribution. Larger values sharpen feature
            attributions at higher computational cost.
        approx_feature_contrib (bool, optional):
            A BOOL value selecting an approximate feature contribution method
            for XGBoost model explanations.

    Returns:
        bigframes.pandas.DataFrame:
            The prediction results with explanations.
    """
    model_name, session = _get_model_name_and_session(model, input_)

    query = bigframes.core.sql.ml.explain_predict(
        model_name=model_name,
        table=_to_sql(input_),
        top_k_features=top_k_features,
        threshold=threshold,
        integrated_gradients_num_steps=integrated_gradients_num_steps,
        approx_feature_contrib=approx_feature_contrib,
    )
    return session.read_gbq(query)
308+
309+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def global_explain(
    model: Union[bigframes.ml.base.BaseEstimator, str],
    *,
    class_level_explain: Optional[bool] = None,
) -> dataframe.DataFrame:
    """
    Gets global feature explanations for a BigQuery ML model via ``ML.GLOBAL_EXPLAIN``.

    See the `BigQuery ML GLOBAL_EXPLAIN function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to get explanations from.
        class_level_explain (bool, optional):
            Whether explanations are returned per class.

    Returns:
        bigframes.pandas.DataFrame:
            The global explanation results.
    """
    # No input data argument here, so the session comes from the model itself
    # (or the global session when a model name string is given).
    model_name, session = _get_model_name_and_session(model)

    query = bigframes.core.sql.ml.global_explain(
        model_name=model_name,
        class_level_explain=class_level_explain,
    )
    return session.read_gbq(query)
0 commit comments