Skip to content

Commit 0356206

Browse files
feat: Add BigQuery ML CREATE MODEL support
- Refactor `bigframes.core.sql` to a package. - Add `bigframes.core.sql.ml` for DDL generation. - Add `bigframes.bigquery.ml` module with `create_model` function. - Add unit tests for SQL generation. - Use `_start_query_ml_ddl` for execution.
1 parent 33a211e commit 0356206

File tree

13 files changed

+319
-0
lines changed

13 files changed

+319
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
from typing import Mapping, Optional, Union
19+
20+
import bigframes.core.sql.ml
21+
import bigframes.core.log_adapter as log_adapter
22+
import bigframes.dataframe as dataframe
23+
import bigframes.session
24+
25+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
26+
def create_model(
27+
model_name: str,
28+
*,
29+
replace: bool = False,
30+
if_not_exists: bool = False,
31+
transform: Optional[list[str]] = None,
32+
input_schema: Optional[Mapping[str, str]] = None,
33+
output_schema: Optional[Mapping[str, str]] = None,
34+
connection_name: Optional[str] = None,
35+
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
36+
query: Optional[Union[dataframe.DataFrame, str]] = None,
37+
training_data: Optional[Union[dataframe.DataFrame, str]] = None,
38+
custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
39+
session: Optional[bigframes.session.Session] = None,
40+
) -> None:
41+
"""
42+
Creates a BigQuery ML model.
43+
"""
44+
import bigframes.pandas as bpd
45+
46+
# Helper to convert DataFrame to SQL string
47+
def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
48+
if isinstance(df_or_sql, str):
49+
return df_or_sql
50+
# It's a DataFrame
51+
sql, _, _ = df_or_sql._to_sql_query(include_index=True)
52+
return sql
53+
54+
query_statement = _to_sql(query) if query is not None else None
55+
training_data_sql = _to_sql(training_data) if training_data is not None else None
56+
custom_holiday_sql = _to_sql(custom_holiday) if custom_holiday is not None else None
57+
58+
# Determine session from DataFrames if not provided
59+
if session is None:
60+
# Try to get session from inputs
61+
dfs = [obj for obj in [query, training_data, custom_holiday] if hasattr(obj, "_session")]
62+
if dfs:
63+
session = dfs[0]._session
64+
65+
sql = bigframes.core.sql.ml.create_model_ddl(
66+
model_name=model_name,
67+
replace=replace,
68+
if_not_exists=if_not_exists,
69+
transform=transform,
70+
input_schema=input_schema,
71+
output_schema=output_schema,
72+
connection_name=connection_name,
73+
options=options,
74+
query_statement=query_statement,
75+
training_data=training_data_sql,
76+
custom_holiday=custom_holiday_sql,
77+
)
78+
79+
if session is None:
80+
session = bpd.get_global_session()
81+
82+
# Use _start_query_ml_ddl which is designed for this
83+
session._start_query_ml_ddl(sql)

bigframes/bigquery/ml.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""This module integrates BigQuery ML functions."""
16+
17+
from bigframes.bigquery._operations.ml import create_model
18+
19+
__all__ = [
20+
"create_model",
21+
]

bigframes/core/sql/ml.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import typing
18+
from typing import Mapping, Optional, Union
19+
20+
import bigframes.core.compile.googlesql as googlesql
21+
import bigframes.core.sql
22+
23+
def create_model_ddl(
24+
model_name: str,
25+
*,
26+
replace: bool = False,
27+
if_not_exists: bool = False,
28+
transform: Optional[list[str]] = None,
29+
input_schema: Optional[Mapping[str, str]] = None,
30+
output_schema: Optional[Mapping[str, str]] = None,
31+
connection_name: Optional[str] = None,
32+
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
33+
query_statement: Optional[str] = None,
34+
training_data: Optional[str] = None,
35+
custom_holiday: Optional[str] = None,
36+
) -> str:
37+
"""Encode the CREATE MODEL statement."""
38+
39+
if replace:
40+
create = "CREATE OR REPLACE MODEL "
41+
elif if_not_exists:
42+
create = "CREATE MODEL IF NOT EXISTS "
43+
else:
44+
create = "CREATE MODEL "
45+
46+
ddl = f"{create}{googlesql.identifier(model_name)}\n"
47+
48+
# [TRANSFORM (select_list)]
49+
if transform:
50+
ddl += f"TRANSFORM ({', '.join(transform)})\n"
51+
52+
# [INPUT (field_name field_type) OUTPUT (field_name field_type)]
53+
if input_schema:
54+
inputs = [f"{k} {v}" for k, v in input_schema.items()]
55+
ddl += f"INPUT ({', '.join(inputs)})\n"
56+
57+
if output_schema:
58+
outputs = [f"{k} {v}" for k, v in output_schema.items()]
59+
ddl += f"OUTPUT ({', '.join(outputs)})\n"
60+
61+
# [REMOTE WITH CONNECTION {connection_name | DEFAULT}]
62+
if connection_name:
63+
if connection_name.upper() == "DEFAULT":
64+
ddl += "REMOTE WITH CONNECTION DEFAULT\n"
65+
else:
66+
ddl += f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"
67+
68+
# [OPTIONS(model_option_list)]
69+
if options:
70+
rendered_options = []
71+
for option_name, option_value in options.items():
72+
if isinstance(option_value, (list, tuple)):
73+
# Handle list options like model_registry="vertex_ai"
74+
# wait, usually options are key=value.
75+
# if value is list, it is [val1, val2]
76+
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
77+
else:
78+
rendered_val = bigframes.core.sql.simple_literal(option_value)
79+
80+
rendered_options.append(f"{option_name} = {rendered_val}")
81+
82+
ddl += f"OPTIONS({', '.join(rendered_options)})\n"
83+
84+
# [AS {query_statement | ( training_data AS (query_statement), custom_holiday AS (holiday_statement) )}]
85+
86+
if query_statement and (training_data or custom_holiday):
87+
raise ValueError("Cannot specify both `query_statement` and (`training_data` or `custom_holiday`).")
88+
89+
if query_statement:
90+
ddl += f"AS {query_statement}"
91+
elif training_data:
92+
# specialized AS clause
93+
parts = []
94+
parts.append(f"training_data AS ({training_data})")
95+
if custom_holiday:
96+
parts.append(f"custom_holiday AS ({custom_holiday})")
97+
98+
ddl += f"AS (\n {', '.join(parts)}\n)"
99+
100+
return ddl
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_project.my_dataset.my_model`
2+
OPTIONS(model_type = 'LINEAR_REG', input_label_cols = ['label'])
3+
AS SELECT * FROM my_table
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL IF NOT EXISTS `my_model`
2+
OPTIONS(model_type = 'KMEANS')
3+
AS SELECT * FROM t
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_model`
2+
OPTIONS(hidden_units = [32, 16], dropout = 0.2)
3+
AS SELECT * FROM t
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
CREATE MODEL `my_remote_model`
2+
INPUT (prompt STRING)
3+
OUTPUT (content STRING)
4+
REMOTE WITH CONNECTION `my_project.us.my_connection`
5+
OPTIONS(endpoint = 'gemini-pro')
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE MODEL `my_remote_model`
2+
REMOTE WITH CONNECTION DEFAULT
3+
OPTIONS(endpoint = 'gemini-pro')
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CREATE OR REPLACE MODEL `my_model`
2+
OPTIONS(model_type = 'LOGISTIC_REG')
3+
AS SELECT * FROM t

0 commit comments

Comments
 (0)