feat: add bigframes.bigquery.ml methods #2300
@@ -0,0 +1,83 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing
from typing import Mapping, Optional, Union

import bigframes.core.sql.ml
import bigframes.core.log_adapter as log_adapter
import bigframes.dataframe as dataframe
import bigframes.session


@log_adapter.method_logger(custom_base_name="bigquery_ml")
def create_model(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    query: Optional[Union[dataframe.DataFrame, str]] = None,
    training_data: Optional[Union[dataframe.DataFrame, str]] = None,
    custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
    session: Optional[bigframes.session.Session] = None,
) -> None:
    """
    Creates a BigQuery ML model.
    """
    import bigframes.pandas as bpd

    # Helper to convert a DataFrame to a SQL string.
    def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
        if isinstance(df_or_sql, str):
            return df_or_sql
        # It's a DataFrame.
        sql, _, _ = df_or_sql._to_sql_query(include_index=True)
        return sql

    query_statement = _to_sql(query) if query is not None else None
    training_data_sql = _to_sql(training_data) if training_data is not None else None
    custom_holiday_sql = _to_sql(custom_holiday) if custom_holiday is not None else None

    # Determine the session from the DataFrames if not provided.
    if session is None:
        # Try to get the session from the inputs.
        dfs = [
            obj
            for obj in [query, training_data, custom_holiday]
            if hasattr(obj, "_session")
        ]
        if dfs:
            session = dfs[0]._session

    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name=model_name,
        replace=replace,
        if_not_exists=if_not_exists,
        transform=transform,
        input_schema=input_schema,
        output_schema=output_schema,
        connection_name=connection_name,
        options=options,
        query_statement=query_statement,
        training_data=training_data_sql,
        custom_holiday=custom_holiday_sql,
    )

    if session is None:
        session = bpd.get_global_session()

    # Use _start_query_ml_ddl, which is designed for this.
    session._start_query_ml_ddl(sql)
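For context, a minimal usage sketch of the new entry point as defined in this commit. The project, dataset, table, and column names below are placeholders, and the `bigframes.bigquery.ml` import path is inferred from the PR title rather than shown explicitly in this diff:

import bigframes.pandas as bpd
import bigframes.bigquery.ml as bq_ml  # module added by this PR (path inferred)

# Placeholder inputs; substitute real time-series and holiday tables.
sales = bpd.read_gbq("my_project.my_dataset.sales")
holidays = bpd.read_gbq("my_project.my_dataset.holidays")

# Renders and runs a CREATE OR REPLACE MODEL ... AS (training_data AS (...),
# custom_holiday AS (...)) statement via the session's ML DDL path.
bq_ml.create_model(
    "my_project.my_dataset.my_arima_model",
    replace=True,
    options={
        "model_type": "ARIMA_PLUS",
        "time_series_timestamp_col": "ts",  # hypothetical column names
        "time_series_data_col": "amount",
    },
    training_data=sales,
    custom_holiday=holidays,
)

The session is picked up from the input DataFrames when not passed explicitly, falling back to the global session otherwise.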
@@ -0,0 +1,21 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module integrates BigQuery ML functions."""

from bigframes.bigquery._operations.ml import create_model

__all__ = [
    "create_model",
]
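As a quick check of the public surface this `__init__` establishes (assuming the module lands at `bigframes.bigquery.ml`, per the PR title), the re-export makes the implementation reachable under the public namespace:

import bigframes.bigquery.ml as bq_ml
from bigframes.bigquery._operations.ml import create_model as _impl

# Both spellings resolve to the same function object.
assert bq_ml.create_model is _impl
assert bq_ml.__all__ == ["create_model"]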
@@ -0,0 +1,100 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing
from typing import Mapping, Optional, Union

import bigframes.core.compile.googlesql as googlesql
import bigframes.core.sql


def create_model_ddl(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    query_statement: Optional[str] = None,
    training_data: Optional[str] = None,
    custom_holiday: Optional[str] = None,
) -> str:
    """Encode the CREATE MODEL statement."""
    if replace:
        create = "CREATE OR REPLACE MODEL "
    elif if_not_exists:
        create = "CREATE MODEL IF NOT EXISTS "
    else:
        create = "CREATE MODEL "

    ddl = f"{create}{googlesql.identifier(model_name)}\n"

    # [TRANSFORM (select_list)]
    if transform:
        ddl += f"TRANSFORM ({', '.join(transform)})\n"

    # [INPUT (field_name field_type) OUTPUT (field_name field_type)]
    if input_schema:
        inputs = [f"{k} {v}" for k, v in input_schema.items()]
        ddl += f"INPUT ({', '.join(inputs)})\n"

    if output_schema:
        outputs = [f"{k} {v}" for k, v in output_schema.items()]
        ddl += f"OUTPUT ({', '.join(outputs)})\n"

    # [REMOTE WITH CONNECTION {connection_name | DEFAULT}]
    if connection_name:
        if connection_name.upper() == "DEFAULT":
            ddl += "REMOTE WITH CONNECTION DEFAULT\n"
        else:
            ddl += f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"

    # [OPTIONS(model_option_list)]
    if options:
        rendered_options = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                # List and tuple option values (e.g. hidden_units=[32, 16])
                # render as BigQuery array literals.
                rendered_val = bigframes.core.sql.simple_literal(list(option_value))
            else:
                rendered_val = bigframes.core.sql.simple_literal(option_value)

            rendered_options.append(f"{option_name} = {rendered_val}")

        ddl += f"OPTIONS({', '.join(rendered_options)})\n"

    # [AS {query_statement | ( training_data AS (query_statement), custom_holiday AS (holiday_statement) )}]
    if query_statement and (training_data or custom_holiday):
        raise ValueError(
            "Cannot specify both `query_statement` and (`training_data` or `custom_holiday`)."
        )

    if query_statement:
        ddl += f"AS {query_statement}"
    elif training_data:
        # Specialized AS clause for time-series inputs.
        parts = []
        parts.append(f"training_data AS ({training_data})")
        if custom_holiday:
            parts.append(f"custom_holiday AS ({custom_holiday})")

        ddl += f"AS (\n {', '.join(parts)}\n)"

    return ddl
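As an illustration of how the options clause is rendered, the call below (mirroring the list-option unit test further down) is expected to produce the `hidden_units`/`dropout` snapshot that follows, assuming `bigframes.core.sql.simple_literal` formats Python lists as BigQuery array literals:

import bigframes.core.sql.ml as sql_ml

ddl = sql_ml.create_model_ddl(
    model_name="my_model",
    options={"hidden_units": [32, 16], "dropout": 0.2},
    query_statement="SELECT * FROM t",
)
print(ddl)
# CREATE MODEL `my_model`
# OPTIONS(hidden_units = [32, 16], dropout = 0.2)
# AS SELECT * FROM t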
@@ -0,0 +1,3 @@
CREATE MODEL `my_project.my_dataset.my_model`
OPTIONS(model_type = 'LINEAR_REG', input_label_cols = ['label'])
AS SELECT * FROM my_table

@@ -0,0 +1,3 @@
CREATE MODEL IF NOT EXISTS `my_model`
OPTIONS(model_type = 'KMEANS')
AS SELECT * FROM t

@@ -0,0 +1,3 @@
CREATE MODEL `my_model`
OPTIONS(hidden_units = [32, 16], dropout = 0.2)
AS SELECT * FROM t

@@ -0,0 +1,5 @@
CREATE MODEL `my_remote_model`
INPUT (prompt STRING)
OUTPUT (content STRING)
REMOTE WITH CONNECTION `my_project.us.my_connection`
OPTIONS(endpoint = 'gemini-pro')

@@ -0,0 +1,3 @@
CREATE MODEL `my_remote_model`
REMOTE WITH CONNECTION DEFAULT
OPTIONS(endpoint = 'gemini-pro')

@@ -0,0 +1,3 @@
CREATE OR REPLACE MODEL `my_model`
OPTIONS(model_type = 'LOGISTIC_REG')
AS SELECT * FROM t

@@ -0,0 +1,5 @@
CREATE MODEL `my_arima_model`
OPTIONS(model_type = 'ARIMA_PLUS')
AS (
 training_data AS (SELECT * FROM sales), custom_holiday AS (SELECT * FROM holidays)
)

@@ -0,0 +1,4 @@
CREATE MODEL `my_model`
TRANSFORM (ML.STANDARD_SCALER(c1) OVER() AS c1_scaled, c2)
OPTIONS(model_type = 'LINEAR_REG')
AS SELECT c1, c2, label FROM t
@@ -0,0 +1,86 @@
# Copyright 2025 Google LLC
Contributor: Are we confident that unit tests alone are enough for these cases? Otherwise we'd still rely on system tests.

Collaborator: I created a notebook test as well.

Contributor (Author): Acknowledged. Thank you for adding the notebook test.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import bigframes.core.sql.ml


def test_create_model_basic(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_project.my_dataset.my_model",
        options={"model_type": "LINEAR_REG", "input_label_cols": ["label"]},
        query_statement="SELECT * FROM my_table",
    )
    snapshot.assert_match(sql, "create_model_basic.sql")


def test_create_model_replace(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        replace=True,
        options={"model_type": "LOGISTIC_REG"},
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(sql, "create_model_replace.sql")


def test_create_model_if_not_exists(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        if_not_exists=True,
        options={"model_type": "KMEANS"},
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(sql, "create_model_if_not_exists.sql")


def test_create_model_transform(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        transform=["ML.STANDARD_SCALER(c1) OVER() AS c1_scaled", "c2"],
        options={"model_type": "LINEAR_REG"},
        query_statement="SELECT c1, c2, label FROM t",
    )
    snapshot.assert_match(sql, "create_model_transform.sql")


def test_create_model_remote(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_remote_model",
        connection_name="my_project.us.my_connection",
        options={"endpoint": "gemini-pro"},
        input_schema={"prompt": "STRING"},
        output_schema={"content": "STRING"},
    )
    snapshot.assert_match(sql, "create_model_remote.sql")


def test_create_model_remote_default(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_remote_model",
        connection_name="DEFAULT",
        options={"endpoint": "gemini-pro"},
    )
    snapshot.assert_match(sql, "create_model_remote_default.sql")


def test_create_model_training_data_and_holiday(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_arima_model",
        options={"model_type": "ARIMA_PLUS"},
        training_data="SELECT * FROM sales",
        custom_holiday="SELECT * FROM holidays",
    )
    snapshot.assert_match(sql, "create_model_training_data_and_holiday.sql")


def test_create_model_list_option(snapshot):
    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        options={"hidden_units": [32, 16], "dropout": 0.2},
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(sql, "create_model_list_option.sql")
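One branch the snapshots above do not exercise is the mutual-exclusion check between `query_statement` and the time-series inputs. A minimal sketch of an additional unit test for it, written for the same test module (not part of this diff):

def test_create_model_query_and_training_data_conflict():
    # create_model_ddl rejects query_statement combined with training_data.
    with pytest.raises(ValueError):
        bigframes.core.sql.ml.create_model_ddl(
            model_name="my_model",
            options={"model_type": "ARIMA_PLUS"},
            query_statement="SELECT * FROM t",
            training_data="SELECT * FROM sales",
        )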
Remove the `query` argument.

Done. Removed the `query` argument from `create_model`.