feat: add bigframes.bigquery.ml methods (#2300)
@@ -0,0 +1,90 @@
```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Mapping, Optional, TYPE_CHECKING, Union

import bigframes.core.log_adapter as log_adapter
import bigframes.core.sql.ml
import bigframes.dataframe as dataframe

if TYPE_CHECKING:
    import bigframes.ml.base
    import bigframes.session


# Helper to convert a DataFrame (or pass through a SQL string) to a SQL query.
def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
    if isinstance(df_or_sql, str):
        return df_or_sql
    # It's a DataFrame.
    sql, _, _ = df_or_sql._to_sql_query(include_index=False)
    return sql


@log_adapter.method_logger(custom_base_name="bigquery_ml")
def create_model(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    training_data: Optional[Union[dataframe.DataFrame, str]] = None,
    custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
    session: Optional[bigframes.session.Session] = None,
) -> bigframes.ml.base.BaseEstimator:
    """Creates a BigQuery ML model by executing a CREATE MODEL statement."""
    import bigframes.pandas as bpd

    training_data_sql = _to_sql(training_data) if training_data is not None else None
    custom_holiday_sql = (
        _to_sql(custom_holiday) if custom_holiday is not None else None
    )

    # Determine the session from the DataFrame inputs if not provided.
    if session is None:
        dfs = [
            obj
            for obj in [training_data, custom_holiday]
            if isinstance(obj, dataframe.DataFrame)
        ]
        if dfs:
            session = dfs[0]._session

    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name=model_name,
        replace=replace,
        if_not_exists=if_not_exists,
        transform=transform,
        input_schema=input_schema,
        output_schema=output_schema,
        connection_name=connection_name,
        options=options,
        training_data=training_data_sql,
        custom_holiday=custom_holiday_sql,
    )

    if session is None:
        session = bpd.get_global_session()

    # _start_query_ml_ddl executes BigQuery ML DDL statements such as CREATE MODEL.
    session._start_query_ml_ddl(sql)

    return session.read_gbq_model(model_name)
```
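For orientation, a minimal usage sketch of this new entry point. The project, dataset, and table names below are placeholders, not part of this PR:

```python
import bigframes.pandas as bpd
import bigframes.bigquery.ml as bq_ml

# Hypothetical source table; a SQL string also works as training_data.
df = bpd.read_gbq("my_project.my_dataset.my_table")

model = bq_ml.create_model(
    "my_project.my_dataset.my_model",
    replace=True,
    options={"model_type": "LINEAR_REG", "input_label_cols": ["label"]},
    training_data=df,
)
# Runs the generated CREATE MODEL DDL, then returns the estimator
# from session.read_gbq_model(...).
```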
@@ -0,0 +1,26 @@
```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module exposes `BigQuery ML
<https://docs.cloud.google.com/bigquery/docs/bqml-introduction>`_ functions
by directly mapping to the equivalent function names in SQL syntax.

For an interface more familiar to scikit-learn users, see :mod:`bigframes.ml`.
"""

from bigframes.bigquery._operations.ml import create_model

__all__ = [
    "create_model",
]
```
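A quick sketch of the resulting import surface, assuming the package ships as part of bigframes:

```python
# Both spellings resolve to the same function re-exported above.
import bigframes.bigquery.ml as bq_ml
from bigframes.bigquery import ml

assert bq_ml.create_model is ml.create_model
```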
@@ -0,0 +1,99 @@
```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Mapping, Optional, Union

import bigframes.core.compile.googlesql as googlesql
import bigframes.core.sql


def create_model_ddl(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    training_data: Optional[str] = None,
    custom_holiday: Optional[str] = None,
) -> str:
    """Encode the CREATE MODEL statement.

    See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create for reference.
    """
    if replace:
        create = "CREATE OR REPLACE MODEL "
    elif if_not_exists:
        create = "CREATE MODEL IF NOT EXISTS "
    else:
        create = "CREATE MODEL "

    ddl = f"{create}{googlesql.identifier(model_name)}\n"

    # [TRANSFORM (select_list)]
    if transform:
        ddl += f"TRANSFORM ({', '.join(transform)})\n"

    # [INPUT (field_name field_type) OUTPUT (field_name field_type)]
    if input_schema:
        inputs = [f"{k} {v}" for k, v in input_schema.items()]
        ddl += f"INPUT ({', '.join(inputs)})\n"

    if output_schema:
        outputs = [f"{k} {v}" for k, v in output_schema.items()]
        ddl += f"OUTPUT ({', '.join(outputs)})\n"

    # [REMOTE WITH CONNECTION {connection_name | DEFAULT}]
    if connection_name:
        if connection_name.upper() == "DEFAULT":
            ddl += "REMOTE WITH CONNECTION DEFAULT\n"
        else:
            ddl += f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"

    # [OPTIONS(model_option_list)]
    if options:
        rendered_options = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                # List-valued options (e.g. input_label_cols or hidden_units)
                # render as an array literal such as ['label'] or [32, 16].
                rendered_val = bigframes.core.sql.simple_literal(list(option_value))
            else:
                rendered_val = bigframes.core.sql.simple_literal(option_value)

            rendered_options.append(f"{option_name} = {rendered_val}")

        ddl += f"OPTIONS({', '.join(rendered_options)})\n"

    # [AS {query_statement | ( training_data AS (query_statement),
    #                          custom_holiday AS (holiday_statement) )}]
    if training_data:
        if custom_holiday:
            # When custom_holiday is present, named clauses are required.
            parts = []
            parts.append(f"training_data AS ({training_data})")
            parts.append(f"custom_holiday AS ({custom_holiday})")
            ddl += f"AS (\n {', '.join(parts)}\n)"
        else:
            # training_data alone is used directly as the query_statement.
            ddl += f"AS {training_data}"

    return ddl
```
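As a sanity check of the CREATE branching at the top of the helper, a small sketch of calling it directly; the model and table names are placeholders, and the expected output matches the CREATE OR REPLACE snapshot further below:

```python
from bigframes.core.sql.ml import create_model_ddl

ddl = create_model_ddl(
    model_name="my_model",
    replace=True,
    options={"model_type": "LOGISTIC_REG"},
    training_data="SELECT * FROM t",
)
print(ddl)
# CREATE OR REPLACE MODEL `my_model`
# OPTIONS(model_type = 'LOGISTIC_REG')
# AS SELECT * FROM t
```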
@@ -10,6 +10,7 @@ packages.
```rst
   bigframes._config
   bigframes.bigquery
   bigframes.bigquery.ai
   bigframes.bigquery.ml
   bigframes.enums
   bigframes.exceptions
   bigframes.geopandas
```
@@ -26,6 +27,8 @@ scikit-learn.
```rst
.. autosummary::
   :toctree: api

   bigframes.ml
   bigframes.ml.base
   bigframes.ml.cluster
   bigframes.ml.compose
   bigframes.ml.decomposition
```
@@ -35,6 +38,7 @@ scikit-learn.
```rst
   bigframes.ml.impute
   bigframes.ml.linear_model
   bigframes.ml.llm
   bigframes.ml.metrics
   bigframes.ml.model_selection
   bigframes.ml.pipeline
   bigframes.ml.preprocessing
```
@@ -0,0 +1,3 @@
```sql
CREATE MODEL `my_project.my_dataset.my_model`
OPTIONS(model_type = 'LINEAR_REG', input_label_cols = ['label'])
AS SELECT * FROM my_table
```
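Presumably this snapshot corresponds to a call through the public API where training_data is given as a SQL string rather than a DataFrame, along the lines of the sketch below (all names are placeholders). The DDL above is what would then be handed to `session._start_query_ml_ddl`:

```python
import bigframes.bigquery.ml as bq_ml

bq_ml.create_model(
    "my_project.my_dataset.my_model",
    options={"model_type": "LINEAR_REG", "input_label_cols": ["label"]},
    training_data="SELECT * FROM my_table",
)
```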
@@ -0,0 +1,3 @@
```sql
CREATE MODEL IF NOT EXISTS `my_model`
OPTIONS(model_type = 'KMEANS')
AS SELECT * FROM t
```
@@ -0,0 +1,3 @@
```sql
CREATE MODEL `my_model`
OPTIONS(hidden_units = [32, 16], dropout = 0.2)
AS SELECT * FROM t
```
@@ -0,0 +1,5 @@
```sql
CREATE MODEL `my_remote_model`
INPUT (prompt STRING)
OUTPUT (content STRING)
REMOTE WITH CONNECTION `my_project.us.my_connection`
OPTIONS(endpoint = 'gemini-pro')
```
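A sketch of the `create_model_ddl` arguments that would produce this remote-model DDL; the connection path and endpoint are placeholder values. Passing `connection_name="DEFAULT"` instead would render the `REMOTE WITH CONNECTION DEFAULT` variant shown in the next snapshot:

```python
from bigframes.core.sql.ml import create_model_ddl

ddl = create_model_ddl(
    model_name="my_remote_model",
    input_schema={"prompt": "STRING"},
    output_schema={"content": "STRING"},
    connection_name="my_project.us.my_connection",
    options={"endpoint": "gemini-pro"},
)
```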
@@ -0,0 +1,3 @@
```sql
CREATE MODEL `my_remote_model`
REMOTE WITH CONNECTION DEFAULT
OPTIONS(endpoint = 'gemini-pro')
```
@@ -0,0 +1,3 @@
```sql
CREATE OR REPLACE MODEL `my_model`
OPTIONS(model_type = 'LOGISTIC_REG')
AS SELECT * FROM t
```
@@ -0,0 +1,5 @@
```sql
CREATE MODEL `my_arima_model`
OPTIONS(model_type = 'ARIMA_PLUS')
AS (
 training_data AS (SELECT * FROM sales), custom_holiday AS (SELECT * FROM holidays)
)
```
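This is the named-clause form that `create_model_ddl` emits when both `training_data` and `custom_holiday` are supplied; a sketch of the corresponding call, with placeholder table names:

```python
from bigframes.core.sql.ml import create_model_ddl

ddl = create_model_ddl(
    model_name="my_arima_model",
    options={"model_type": "ARIMA_PLUS"},
    training_data="SELECT * FROM sales",
    custom_holiday="SELECT * FROM holidays",
)
```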
@@ -0,0 +1,4 @@
```sql
CREATE MODEL `my_model`
TRANSFORM (ML.STANDARD_SCALER(c1) OVER() AS c1_scaled, c2)
OPTIONS(model_type = 'LINEAR_REG')
AS SELECT c1, c2, label FROM t
```
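The `transform` argument is passed through verbatim as a list of select-list expressions; a sketch of the call that would yield this snapshot:

```python
from bigframes.core.sql.ml import create_model_ddl

ddl = create_model_ddl(
    model_name="my_model",
    transform=["ML.STANDARD_SCALER(c1) OVER() AS c1_scaled", "c2"],
    options={"model_type": "LINEAR_REG"},
    training_data="SELECT c1, c2, label FROM t",
)
```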