Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0356206
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 1, 2025
54050d5
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 1, 2025
fdee53f
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 1, 2025
c1adfd9
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 1, 2025
8b08869
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 1, 2025
bfc4fdf
fix lint errors
tswast Dec 2, 2025
15d1d59
update docs
tswast Dec 2, 2025
261073d
revert sample
tswast Dec 2, 2025
50e98ff
Merge branch 'main' into create-model-support
tswast Dec 2, 2025
b809f81
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 2, 2025
7df5e09
Revert "feat: Add BigQuery ML CREATE MODEL support"
tswast Dec 2, 2025
9231f56
restore docstring from jules
tswast Dec 2, 2025
f323b6b
link to sql reference
tswast Dec 2, 2025
2d5e065
add todo for other transform inputs
tswast Dec 2, 2025
c86e15a
create bbq notebook
tswast Dec 2, 2025
9ad9011
add more functions
tswast Dec 2, 2025
b355e47
add more functions
tswast Dec 2, 2025
b4e31ef
fix struct options
tswast Dec 2, 2025
a59d746
add sample notebook
tswast Dec 2, 2025
fba9326
feat: Add BigQuery ML CREATE MODEL support
google-labs-jules[bot] Dec 3, 2025
74d4fcc
Revert "feat: Add BigQuery ML CREATE MODEL support"
tswast Dec 3, 2025
68e770b
return pd.Series from `create_model`
tswast Dec 3, 2025
82d1aec
support pandas inputs
tswast Dec 3, 2025
1d88694
add unit tests
tswast Dec 3, 2025
84f4427
add less mocking
tswast Dec 3, 2025
e43636b
Merge remote-tracking branch 'origin/main' into create-model-support
tswast Dec 3, 2025
9baeb8a
skip snapshot tests where not present
tswast Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions bigframes/bigquery/_operations/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing
from typing import Mapping, Optional, Union

import bigframes.core.sql.ml
import bigframes.core.log_adapter as log_adapter
import bigframes.dataframe as dataframe
import bigframes.ml.base
import bigframes.session

def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
    """Return the SQL text for a ``create_model`` data input.

    Args:
        df_or_sql: Either a SQL string, which is returned unchanged, or a
            DataFrame, which is compiled to a SQL query.

    Returns:
        The SQL query text.
    """
    if isinstance(df_or_sql, str):
        return df_or_sql
    # Anything that is not a string is treated as a DataFrame. The index is
    # left out of the compiled query so it does not end up as a column of the
    # training data.
    sql, _, _ = df_or_sql._to_sql_query(include_index=False)
    return sql


@log_adapter.method_logger(custom_base_name="bigquery_ml")
def create_model(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    query: Optional[Union[dataframe.DataFrame, str]] = None,
    training_data: Optional[Union[dataframe.DataFrame, str]] = None,
    custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
    session: Optional[bigframes.session.Session] = None,
) -> bigframes.ml.base.BaseModel:
    """Creates a BigQuery ML model.

    Builds a ``CREATE MODEL`` statement with
    ``bigframes.core.sql.ml.create_model_ddl``, runs it on the session, and
    then loads the model named ``model_name`` from that session.

    Args:
        model_name: Name of the model to create.
        replace: Emit ``CREATE OR REPLACE MODEL``.
        if_not_exists: Emit ``CREATE MODEL IF NOT EXISTS``.
        transform: Select-list expressions for the ``TRANSFORM`` clause.
        input_schema: Field name to field type for the ``INPUT`` clause.
        output_schema: Field name to field type for the ``OUTPUT`` clause.
        connection_name: Connection for ``REMOTE WITH CONNECTION``.
        options: Model options for the ``OPTIONS`` clause.
        query: DataFrame or SQL string used as the ``AS query_statement``
            clause.
        training_data: DataFrame or SQL string used as ``training_data``.
        custom_holiday: DataFrame or SQL string used as ``custom_holiday``.
        session: Session to run the statement on. When omitted, the session
            of the first DataFrame-like input is used, falling back to the
            global session.

    Returns:
        The model object returned by ``session.read_gbq_model(model_name)``.
    """
    # NOTE(review): reviewers asked for `query` to be removed in favour of
    # `training_data`; it is kept here so existing keyword callers keep working.
    query_statement = _to_sql(query) if query is not None else None
    training_data_sql = (
        _to_sql(training_data) if training_data is not None else None
    )
    custom_holiday_sql = (
        _to_sql(custom_holiday) if custom_holiday is not None else None
    )

    # Without an explicit session, borrow it from the first input that
    # carries one (SQL strings do not).
    if session is None:
        session = next(
            (
                obj._session
                for obj in (query, training_data, custom_holiday)
                if hasattr(obj, "_session")
            ),
            None,
        )

    sql = bigframes.core.sql.ml.create_model_ddl(
        model_name=model_name,
        replace=replace,
        if_not_exists=if_not_exists,
        transform=transform,
        input_schema=input_schema,
        output_schema=output_schema,
        connection_name=connection_name,
        options=options,
        query_statement=query_statement,
        training_data=training_data_sql,
        custom_holiday=custom_holiday_sql,
    )

    if session is None:
        # Imported lazily, only when the global session is actually needed.
        import bigframes.pandas as bpd

        session = bpd.get_global_session()

    # Run the DDL through the session's ML DDL entry point.
    session._start_query_ml_ddl(sql)

    return session.read_gbq_model(model_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ends up creating a bigframes.ml model. If a user is working with bigquery.ml, I wonder whether they still want a bigframes.ml model object, or whether a job ID or model name would be more natural.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One nice thing about the bigframes.ml model object is that it allows the user to access some metadata about the model. Another alternative is maybe a pandas DataFrame with the model metadata, kinda like we do with dry run.

21 changes: 21 additions & 0 deletions bigframes/bigquery/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module integrates BigQuery ML functions."""

from bigframes.bigquery._operations.ml import create_model

__all__ = [
"create_model",
]
File renamed without changes.
100 changes: 100 additions & 0 deletions bigframes/core/sql/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing
from typing import Mapping, Optional, Union

import bigframes.core.compile.googlesql as googlesql
import bigframes.core.sql

def create_model_ddl(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    query_statement: Optional[str] = None,
    training_data: Optional[str] = None,
    custom_holiday: Optional[str] = None,
) -> str:
    """Encode the CREATE MODEL statement.

    Clauses are emitted in this order, each only when its argument is set:
    ``TRANSFORM``, ``INPUT``, ``OUTPUT``, ``REMOTE WITH CONNECTION``,
    ``OPTIONS`` and ``AS``.

    Args:
        model_name: Model name; rendered as a quoted identifier.
        replace: Emit ``CREATE OR REPLACE MODEL``. Takes precedence over
            ``if_not_exists`` when both are true.
        if_not_exists: Emit ``CREATE MODEL IF NOT EXISTS``.
        transform: Select-list expressions, inserted verbatim.
        input_schema: Field name to field type, inserted verbatim.
        output_schema: Field name to field type, inserted verbatim.
        connection_name: Connection name, rendered as a quoted identifier.
            The value ``DEFAULT`` (any case) is emitted as the bare keyword.
        options: Option name to value; values are rendered as SQL literals.
        query_statement: SQL for the plain ``AS query_statement`` form.
        training_data: SQL for the ``training_data AS (...)`` form.
        custom_holiday: SQL for ``custom_holiday AS (...)``; only valid
            together with ``training_data``.

    Returns:
        The ``CREATE MODEL`` statement text.

    Raises:
        ValueError: If ``query_statement`` is combined with ``training_data``
            or ``custom_holiday``, or if ``custom_holiday`` is given without
            ``training_data``.
    """
    # Validate up front so bad argument combinations fail before any SQL is
    # rendered.
    # NOTE(review): reviewers asked for `query_statement` to be dropped in
    # favour of `training_data`; it is kept so existing callers keep working.
    if query_statement and (training_data or custom_holiday):
        raise ValueError("Cannot specify both `query_statement` and (`training_data` or `custom_holiday`).")
    if custom_holiday and not training_data:
        # Previously the holiday query was silently dropped in this case.
        raise ValueError("`custom_holiday` requires `training_data`.")

    if replace:
        create = "CREATE OR REPLACE MODEL "
    elif if_not_exists:
        create = "CREATE MODEL IF NOT EXISTS "
    else:
        create = "CREATE MODEL "

    clauses = [f"{create}{googlesql.identifier(model_name)}\n"]

    # [TRANSFORM (select_list)]
    if transform:
        clauses.append(f"TRANSFORM ({', '.join(transform)})\n")

    # [INPUT (field_name field_type) OUTPUT (field_name field_type)]
    if input_schema:
        inputs = ", ".join(f"{k} {v}" for k, v in input_schema.items())
        clauses.append(f"INPUT ({inputs})\n")

    if output_schema:
        outputs = ", ".join(f"{k} {v}" for k, v in output_schema.items())
        clauses.append(f"OUTPUT ({outputs})\n")

    # [REMOTE WITH CONNECTION {connection_name | DEFAULT}]
    if connection_name:
        if connection_name.upper() == "DEFAULT":
            clauses.append("REMOTE WITH CONNECTION DEFAULT\n")
        else:
            clauses.append(
                f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"
            )

    # [OPTIONS(model_option_list)]
    if options:
        rendered_options = []
        for option_name, option_value in options.items():
            # Lists and tuples are both handed to the literal renderer as a
            # plain list.
            if isinstance(option_value, (list, tuple)):
                option_value = list(option_value)
            rendered_val = bigframes.core.sql.simple_literal(option_value)
            rendered_options.append(f"{option_name} = {rendered_val}")

        clauses.append(f"OPTIONS({', '.join(rendered_options)})\n")

    # [AS {query_statement | ( training_data AS (query_statement),
    #      custom_holiday AS (holiday_statement) )}]
    if query_statement:
        clauses.append(f"AS {query_statement}")
    elif training_data:
        parts = [f"training_data AS ({training_data})"]
        if custom_holiday:
            parts.append(f"custom_holiday AS ({custom_holiday})")

        clauses.append(f"AS (\n {', '.join(parts)}\n)")

    return "".join(clauses)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL `my_project.my_dataset.my_model`
OPTIONS(model_type = 'LINEAR_REG', input_label_cols = ['label'])
AS SELECT * FROM my_table
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL IF NOT EXISTS `my_model`
OPTIONS(model_type = 'KMEANS')
AS SELECT * FROM t
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL `my_model`
OPTIONS(hidden_units = [32, 16], dropout = 0.2)
AS SELECT * FROM t
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE MODEL `my_remote_model`
INPUT (prompt STRING)
OUTPUT (content STRING)
REMOTE WITH CONNECTION `my_project.us.my_connection`
OPTIONS(endpoint = 'gemini-pro')
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL `my_remote_model`
REMOTE WITH CONNECTION DEFAULT
OPTIONS(endpoint = 'gemini-pro')
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE OR REPLACE MODEL `my_model`
OPTIONS(model_type = 'LOGISTIC_REG')
AS SELECT * FROM t
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE MODEL `my_arima_model`
OPTIONS(model_type = 'ARIMA_PLUS')
AS (
training_data AS (SELECT * FROM sales), custom_holiday AS (SELECT * FROM holidays)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE MODEL `my_model`
TRANSFORM (ML.STANDARD_SCALER(c1) OVER() AS c1_scaled, c2)
OPTIONS(model_type = 'LINEAR_REG')
AS SELECT c1, c2, label FROM t
86 changes: 86 additions & 0 deletions tests/unit/core/sql/test_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2025 Google LLC
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we confident that unit tests alone are enough for these cases? Otherwise we'd still need to rely on system tests.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a notebook test as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged. Thank you for adding the notebook test.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import bigframes.core.sql.ml

def test_create_model_basic(snapshot):
    """Snapshot the DDL for a qualified model name, options and a query."""
    model_options = {"model_type": "LINEAR_REG", "input_label_cols": ["label"]}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_project.my_dataset.my_model",
        options=model_options,
        query_statement="SELECT * FROM my_table",
    )
    snapshot.assert_match(rendered, "create_model_basic.sql")

def test_create_model_replace(snapshot):
    """Snapshot the DDL produced when ``replace=True``."""
    model_options = {"model_type": "LOGISTIC_REG"}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_model",
        replace=True,
        options=model_options,
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(rendered, "create_model_replace.sql")

def test_create_model_if_not_exists(snapshot):
    """Snapshot the DDL produced when ``if_not_exists=True``."""
    model_options = {"model_type": "KMEANS"}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_model",
        if_not_exists=True,
        options=model_options,
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(rendered, "create_model_if_not_exists.sql")

def test_create_model_transform(snapshot):
    """Snapshot the DDL produced when a TRANSFORM select list is given."""
    select_list = ["ML.STANDARD_SCALER(c1) OVER() AS c1_scaled", "c2"]
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_model",
        transform=select_list,
        options={"model_type": "LINEAR_REG"},
        query_statement="SELECT c1, c2, label FROM t",
    )
    snapshot.assert_match(rendered, "create_model_transform.sql")

def test_create_model_remote(snapshot):
    """Snapshot the DDL for a remote model with INPUT/OUTPUT schemas."""
    remote_kwargs = dict(
        connection_name="my_project.us.my_connection",
        options={"endpoint": "gemini-pro"},
        input_schema={"prompt": "STRING"},
        output_schema={"content": "STRING"},
    )
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_remote_model", **remote_kwargs
    )
    snapshot.assert_match(rendered, "create_model_remote.sql")

def test_create_model_remote_default(snapshot):
    """Snapshot the DDL for a remote model on the DEFAULT connection."""
    endpoint_options = {"endpoint": "gemini-pro"}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_remote_model",
        connection_name="DEFAULT",
        options=endpoint_options,
    )
    snapshot.assert_match(rendered, "create_model_remote_default.sql")

def test_create_model_training_data_and_holiday(snapshot):
    """Snapshot the DDL when training_data and custom_holiday are both set."""
    sales_query = "SELECT * FROM sales"
    holidays_query = "SELECT * FROM holidays"
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_arima_model",
        options={"model_type": "ARIMA_PLUS"},
        training_data=sales_query,
        custom_holiday=holidays_query,
    )
    snapshot.assert_match(rendered, "create_model_training_data_and_holiday.sql")

def test_create_model_list_option(snapshot):
    """Snapshot the DDL when an option value is a list."""
    model_options = {"hidden_units": [32, 16], "dropout": 0.2}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        "my_model",
        options=model_options,
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(rendered, "create_model_list_option.sql")
Loading