Skip to content

Commit 74d4fcc

Browse files
committed
Revert "feat: Add BigQuery ML CREATE MODEL support"
This reverts commit fba9326.
1 parent fba9326 commit 74d4fcc

File tree

30 files changed

+2726
-46
lines changed

30 files changed

+2726
-46
lines changed

bigframes/bigquery/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import sys
2020

21-
from bigframes.bigquery import ai
21+
from bigframes.bigquery import ai, ml
2222
from bigframes.bigquery._operations.approx_agg import approx_top_count
2323
from bigframes.bigquery._operations.array import (
2424
array_agg,
@@ -157,4 +157,5 @@
157157
"struct",
158158
# Modules / SQL namespaces
159159
"ai",
160+
"ml",
160161
]

bigframes/bigquery/_operations/ml.py

Lines changed: 231 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,13 @@
1414

1515
from __future__ import annotations
1616

17-
import typing
18-
from typing import Mapping, Optional, TYPE_CHECKING, Union
17+
from typing import Mapping, Optional, Union
1918

2019
import bigframes.core.log_adapter as log_adapter
2120
import bigframes.core.sql.ml
2221
import bigframes.dataframe as dataframe
23-
24-
if TYPE_CHECKING:
25-
import bigframes.ml.base
26-
import bigframes.session
22+
import bigframes.ml.base
23+
import bigframes.session
2724

2825

2926
# Helper to convert DataFrame to SQL string
@@ -35,12 +32,37 @@ def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
3532
return sql
3633

3734

35+
def _get_model_name_and_session(
36+
model: Union[bigframes.ml.base.BaseEstimator, str],
37+
# Other dataframe arguments to extract session from
38+
*dataframes: Optional[Union[dataframe.DataFrame, str]],
39+
) -> tuple[str, bigframes.session.Session]:
40+
import bigframes.pandas as bpd
41+
42+
if isinstance(model, str):
43+
model_name = model
44+
session = None
45+
for df in dataframes:
46+
if isinstance(df, dataframe.DataFrame):
47+
session = df._session
48+
break
49+
if session is None:
50+
session = bpd.get_global_session()
51+
return model_name, session
52+
else:
53+
if model._bqml_model is None:
54+
raise ValueError("Model must be fitted to be used in ML operations.")
55+
return model._bqml_model.model_name, model._bqml_model.session
56+
57+
3858
@log_adapter.method_logger(custom_base_name="bigquery_ml")
3959
def create_model(
4060
model_name: str,
4161
*,
4262
replace: bool = False,
4363
if_not_exists: bool = False,
64+
# TODO(tswast): Also support bigframes.ml transformer classes and/or
65+
# bigframes.pandas functions?
4466
transform: Optional[list[str]] = None,
4567
input_schema: Optional[Mapping[str, str]] = None,
4668
output_schema: Optional[Mapping[str, str]] = None,
@@ -53,6 +75,10 @@ def create_model(
5375
"""
5476
Creates a BigQuery ML model.
5577
78+
See the `BigQuery ML CREATE MODEL DDL syntax
79+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create>`_
80+
for additional reference.
81+
5682
Args:
5783
model_name (str):
5884
The name of the model in BigQuery.
@@ -61,7 +87,8 @@ def create_model(
6187
if_not_exists (bool, default False):
6288
Whether to ignore the error if the model already exists.
6389
transform (list[str], optional):
64-
The TRANSFORM clause, which specifies the preprocessing steps to apply to the input data.
90+
A list of SQL transformations for the TRANSFORM clause, which
91+
specifies the preprocessing steps to apply to the input data.
6592
input_schema (Mapping[str, str], optional):
6693
The INPUT clause, which specifies the schema of the input data.
6794
output_schema (Mapping[str, str], optional):
@@ -70,16 +97,16 @@ def create_model(
7097
The connection to use for the model.
7198
options (Mapping[str, Union[str, int, float, bool, list]], optional):
7299
The OPTIONS clause, which specifies the model options.
73-
training_data (Union[dataframe.DataFrame, str], optional):
100+
training_data (Union[bigframes.pandas.DataFrame, str], optional):
74101
The query or DataFrame to use for training the model.
75-
custom_holiday (Union[dataframe.DataFrame, str], optional):
102+
custom_holiday (Union[bigframes.pandas.DataFrame, str], optional):
76103
The query or DataFrame to use for custom holiday data.
77104
session (bigframes.session.Session, optional):
78-
The BigFrames session to use. If not provided, the default session is used.
105+
The session to use. If not provided, the default session is used.
79106
80107
Returns:
81108
bigframes.ml.base.BaseEstimator:
82-
The created BigFrames model.
109+
The created BigQuery ML model.
83110
"""
84111
import bigframes.pandas as bpd
85112

@@ -117,3 +144,196 @@ def create_model(
117144
session._start_query_ml_ddl(sql)
118145

119146
return session.read_gbq_model(model_name)
147+
148+
149+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
150+
def evaluate(
151+
model: Union[bigframes.ml.base.BaseEstimator, str],
152+
input_: Optional[Union[dataframe.DataFrame, str]] = None,
153+
*,
154+
perform_aggregation: Optional[bool] = None,
155+
horizon: Optional[int] = None,
156+
confidence_level: Optional[float] = None,
157+
) -> dataframe.DataFrame:
158+
"""
159+
Evaluates a BigQuery ML model.
160+
161+
See the `BigQuery ML EVALUATE function syntax
162+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate>`_
163+
for additional reference.
164+
165+
Args:
166+
model (bigframes.ml.base.BaseEstimator or str):
167+
The model to evaluate.
168+
input_ (Union[bigframes.pandas.DataFrame, str], optional):
169+
The DataFrame or query to use for evaluation. If not provided, the
170+
evaluation data from training is used.
171+
perform_aggregation (bool, optional):
172+
A BOOL value that indicates the level of evaluation for forecasting
173+
accuracy. If you specify TRUE, then the forecasting accuracy is on
174+
the time series level. If you specify FALSE, the forecasting
175+
accuracy is on the timestamp level. The default value is TRUE.
176+
horizon (int, optional):
177+
An INT64 value that specifies the number of forecasted time points
178+
against which the evaluation metrics are computed. The default value
179+
is the horizon value specified in the CREATE MODEL statement for the
180+
time series model, or 1000 if unspecified. When evaluating multiple
181+
time series at the same time, this parameter applies to each time
182+
series.
183+
confidence_level (float, optional):
184+
A FLOAT64 value that specifies the percentage of the future values
185+
that fall in the prediction interval. The default value is 0.95. The
186+
valid input range is ``[0, 1)``.
187+
188+
Returns:
189+
bigframes.pandas.DataFrame:
190+
The evaluation results.
191+
"""
192+
model_name, session = _get_model_name_and_session(model, input_)
193+
table_sql = _to_sql(input_) if input_ is not None else None
194+
195+
sql = bigframes.core.sql.ml.evaluate(
196+
model_name=model_name,
197+
table=table_sql,
198+
perform_aggregation=perform_aggregation,
199+
horizon=horizon,
200+
confidence_level=confidence_level,
201+
)
202+
203+
return session.read_gbq(sql)
204+
205+
206+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
207+
def predict(
208+
model: Union[bigframes.ml.base.BaseEstimator, str],
209+
input_: Union[dataframe.DataFrame, str],
210+
*,
211+
threshold: Optional[float] = None,
212+
keep_original_columns: Optional[bool] = None,
213+
trial_id: Optional[int] = None,
214+
) -> dataframe.DataFrame:
215+
"""
216+
Runs prediction on a BigQuery ML model.
217+
218+
See the `BigQuery ML PREDICT function syntax
219+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict>`_
220+
for additional reference.
221+
222+
Args:
223+
model (bigframes.ml.base.BaseEstimator or str):
224+
The model to use for prediction.
225+
input_ (Union[bigframes.pandas.DataFrame, str]):
226+
The DataFrame or query to use for prediction.
227+
threshold (float, optional):
228+
The custom threshold to use for binary classification models.
229+
keep_original_columns (bool, optional):
230+
Whether to keep the original columns in the output.
231+
trial_id (int, optional):
232+
An INT64 value that identifies the hyperparameter tuning trial that
233+
you want the function to evaluate. The function uses the optimal
234+
trial by default. Only specify this argument if you ran
235+
hyperparameter tuning when creating the model.
236+
237+
Returns:
238+
bigframes.pandas.DataFrame:
239+
The prediction results.
240+
"""
241+
model_name, session = _get_model_name_and_session(model, input_)
242+
table_sql = _to_sql(input_)
243+
244+
sql = bigframes.core.sql.ml.predict(
245+
model_name=model_name,
246+
table=table_sql,
247+
threshold=threshold,
248+
keep_original_columns=keep_original_columns,
249+
trial_id=trial_id,
250+
)
251+
252+
return session.read_gbq(sql)
253+
254+
255+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
256+
def explain_predict(
257+
model: Union[bigframes.ml.base.BaseEstimator, str],
258+
input_: Union[dataframe.DataFrame, str],
259+
*,
260+
top_k_features: Optional[int] = None,
261+
threshold: Optional[float] = None,
262+
integrated_gradients_num_steps: Optional[int] = None,
263+
approx_feature_contrib: Optional[bool] = None,
264+
) -> dataframe.DataFrame:
265+
"""
266+
Runs explainable prediction on a BigQuery ML model.
267+
268+
See the `BigQuery ML EXPLAIN_PREDICT function syntax
269+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict>`_
270+
for additional reference.
271+
272+
Args:
273+
model (bigframes.ml.base.BaseEstimator or str):
274+
The model to use for prediction.
275+
input_ (Union[bigframes.pandas.DataFrame, str]):
276+
The DataFrame or query to use for prediction.
277+
top_k_features (int, optional):
278+
The number of top features to return.
279+
threshold (float, optional):
280+
The threshold for binary classification models.
281+
integrated_gradients_num_steps (int, optional):
282+
An INT64 value that specifies the number of steps to sample between
283+
the example being explained and its baseline. This value is used to
284+
approximate the integral in integrated gradients attribution
285+
methods. Increasing the value improves the precision of feature
286+
attributions, but can be slower and more computationally expensive.
287+
approx_feature_contrib (bool, optional):
288+
A BOOL value that indicates whether to use an approximate feature
289+
contribution method in the XGBoost model explanation.
290+
291+
Returns:
292+
bigframes.pandas.DataFrame:
293+
The prediction results with explanations.
294+
"""
295+
model_name, session = _get_model_name_and_session(model, input_)
296+
table_sql = _to_sql(input_)
297+
298+
sql = bigframes.core.sql.ml.explain_predict(
299+
model_name=model_name,
300+
table=table_sql,
301+
top_k_features=top_k_features,
302+
threshold=threshold,
303+
integrated_gradients_num_steps=integrated_gradients_num_steps,
304+
approx_feature_contrib=approx_feature_contrib,
305+
)
306+
307+
return session.read_gbq(sql)
308+
309+
310+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
311+
def global_explain(
312+
model: Union[bigframes.ml.base.BaseEstimator, str],
313+
*,
314+
class_level_explain: Optional[bool] = None,
315+
) -> dataframe.DataFrame:
316+
"""
317+
Gets global explanations for a BigQuery ML model.
318+
319+
See the `BigQuery ML GLOBAL_EXPLAIN function syntax
320+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain>`_
321+
for additional reference.
322+
323+
Args:
324+
model (bigframes.ml.base.BaseEstimator or str):
325+
The model to get explanations from.
326+
class_level_explain (bool, optional):
327+
Whether to return class-level explanations.
328+
329+
Returns:
330+
bigframes.pandas.DataFrame:
331+
The global explanation results.
332+
"""
333+
model_name, session = _get_model_name_and_session(model)
334+
sql = bigframes.core.sql.ml.global_explain(
335+
model_name=model_name,
336+
class_level_explain=class_level_explain,
337+
)
338+
339+
return session.read_gbq(sql)

bigframes/bigquery/ml.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""This module integrates BigQuery ML functions."""
15+
"""This module exposes `BigQuery ML
16+
<https://docs.cloud.google.com/bigquery/docs/bqml-introduction>`_ functions
17+
by directly mapping to the equivalent function names in SQL syntax.
1618
17-
from bigframes.bigquery._operations.ml import create_model
19+
For an interface more familiar to scikit-learn users, see :mod:`bigframes.ml`.
20+
"""
21+
22+
from bigframes.bigquery._operations.ml import (
23+
create_model,
24+
evaluate,
25+
explain_predict,
26+
global_explain,
27+
predict,
28+
)
1829

1930
__all__ = [
2031
"create_model",
32+
"evaluate",
33+
"predict",
34+
"explain_predict",
35+
"global_explain",
2136
]

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
107107
sge.If(this=sge.convert(key), true=sge.convert(value))
108108
for key, value in op.mappings
109109
],
110+
default=expr.expr,
110111
)
111112

112113

0 commit comments

Comments
 (0)