Skip to content

Commit 61200bd

Browse files
authored
refactor: push down SQL generate logic in core.BqmlModel (#66)
1 parent 392113b commit 61200bd

File tree

3 files changed

+80
-55
lines changed

3 files changed

+80
-55
lines changed

bigframes/ml/core.py

Lines changed: 13 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def model(self) -> bigquery.Model:
5858
def _apply_sql(
5959
self,
6060
input_data: bpd.DataFrame,
61-
func: Callable[[str], str],
61+
func: Callable[[bpd.DataFrame], str],
6262
) -> bpd.DataFrame:
6363
"""Helper to wrap a dataframe in a SQL query, keeping the index intact.
6464
@@ -74,11 +74,9 @@ def _apply_sql(
7474
string from which to construct the output dataframe. It must
7575
include the index columns of the input SQL.
7676
"""
77-
source_sql, index_col_ids, index_labels = input_data._to_sql_query(
78-
include_index=True
79-
)
77+
_, index_col_ids, index_labels = input_data._to_sql_query(include_index=True)
8078

81-
sql = func(source_sql)
79+
sql = func(input_data)
8280
df = self._session.read_gbq(sql, index_col=index_col_ids)
8381
df.index.names = index_labels
8482

@@ -106,11 +104,9 @@ def generate_text(
106104
# TODO: validate input data schema
107105
return self._apply_sql(
108106
input_data,
109-
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text(
110-
source_sql=source_sql,
111-
struct_options=self._model_manipulation_sql_generator.struct_options(
112-
**options
113-
),
107+
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text(
108+
source_df=source_df,
109+
struct_options=options,
114110
),
115111
)
116112

@@ -122,11 +118,9 @@ def generate_text_embedding(
122118
# TODO: validate input data schema
123119
return self._apply_sql(
124120
input_data,
125-
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text_embedding(
126-
source_sql=source_sql,
127-
struct_options=self._model_manipulation_sql_generator.struct_options(
128-
**options
129-
),
121+
lambda source_df: self._model_manipulation_sql_generator.ml_generate_text_embedding(
122+
source_df=source_df,
123+
struct_options=options,
130124
),
131125
)
132126

@@ -136,13 +130,7 @@ def forecast(self) -> bpd.DataFrame:
136130

137131
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
138132
# TODO: validate input data schema
139-
# Note: don't need index as evaluate returns a new table
140-
source_sql, _, _ = (
141-
input_data._to_sql_query(include_index=False)
142-
if (input_data is not None)
143-
else (None, None, None)
144-
)
145-
sql = self._model_manipulation_sql_generator.ml_evaluate(source_sql)
133+
sql = self._model_manipulation_sql_generator.ml_evaluate(input_data)
146134

147135
return self._session.read_gbq(sql)
148136

@@ -188,11 +176,8 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
188176
# Truncate, as a Vertex ID only accepts 63 characters and temp model IDs easily exceed that limit.
189177
# The possibility of conflicts should be low.
190178
vertex_ai_model_id = vertex_ai_model_id[:63]
191-
options_sql = self._model_manipulation_sql_generator.options(
192-
**{"vertex_ai_model_id": vertex_ai_model_id}
193-
)
194179
sql = self._model_manipulation_sql_generator.alter_model(
195-
options_sql=options_sql
180+
options={"vertex_ai_model_id": vertex_ai_model_id}
196181
)
197182
# Register the model and wait for it to finish
198183
self._session._start_query(sql)
@@ -252,7 +237,7 @@ def create_model(
252237
session = X_train._session
253238

254239
sql = self._model_creation_sql_generator.create_model(
255-
source=input_data,
240+
source_df=input_data,
256241
transforms=transforms,
257242
options=options,
258243
)
@@ -281,7 +266,7 @@ def create_time_series_model(
281266
session = X_train._session
282267

283268
sql = self._model_creation_sql_generator.create_model(
284-
source=input_data,
269+
source_df=input_data,
285270
transforms=transforms,
286271
options=options,
287272
)

bigframes/ml/sql.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,12 @@ def __init__(self, model_id: str):
118118
# Model create and alter
119119
def create_model(
120120
self,
121-
source: bpd.DataFrame,
121+
source_df: bpd.DataFrame,
122122
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
123123
transforms: Optional[Iterable[str]] = None,
124124
) -> str:
125125
"""Encode the CREATE TEMP MODEL statement for BQML"""
126-
source_sql = source.sql
126+
source_sql = source_df.sql
127127
transform_sql = self.transform(*transforms) if transforms is not None else None
128128
options_sql = self.options(**options)
129129

@@ -168,39 +168,58 @@ class ModelManipulationSqlGenerator(BaseSqlGenerator):
168168
def __init__(self, model_name: str):
169169
self._model_name = model_name
170170

171+
def _source_sql(self, source_df: bpd.DataFrame) -> str:
172+
"""Return DataFrame sql with index columns."""
173+
_source_sql, _, _ = source_df._to_sql_query(include_index=True)
174+
return _source_sql
175+
171176
# Alter model
172177
def alter_model(
173178
self,
174-
options_sql: str,
179+
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
175180
) -> str:
176181
"""Encode the ALTER MODEL statement for BQML"""
182+
options_sql = self.options(**options)
183+
177184
parts = [f"ALTER MODEL `{self._model_name}`"]
178185
parts.append(f"SET {options_sql}")
179186
return "\n".join(parts)
180187

181188
# ML prediction TVFs
182-
def ml_predict(self, source_sql: str) -> str:
189+
def ml_predict(self, source_df: bpd.DataFrame) -> str:
183190
"""Encode ML.PREDICT for BQML"""
184191
return f"""SELECT * FROM ML.PREDICT(MODEL `{self._model_name}`,
185-
({source_sql}))"""
192+
({self._source_sql(source_df)}))"""
186193

187194
def ml_forecast(self) -> str:
188195
"""Encode ML.FORECAST for BQML"""
189196
return f"""SELECT * FROM ML.FORECAST(MODEL `{self._model_name}`)"""
190197

191-
def ml_generate_text(self, source_sql: str, struct_options: str) -> str:
198+
def ml_generate_text(
199+
self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
200+
) -> str:
192201
"""Encode ML.GENERATE_TEXT for BQML"""
202+
struct_options_sql = self.struct_options(**struct_options)
193203
return f"""SELECT * FROM ML.GENERATE_TEXT(MODEL `{self._model_name}`,
194-
({source_sql}), {struct_options})"""
204+
({self._source_sql(source_df)}), {struct_options_sql})"""
195205

196-
def ml_generate_text_embedding(self, source_sql: str, struct_options: str) -> str:
206+
def ml_generate_text_embedding(
207+
self, source_df: bpd.DataFrame, struct_options: Mapping[str, Union[int, float]]
208+
) -> str:
197209
"""Encode ML.GENERATE_TEXT_EMBEDDING for BQML"""
210+
struct_options_sql = self.struct_options(**struct_options)
198211
return f"""SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `{self._model_name}`,
199-
({source_sql}), {struct_options})"""
212+
({self._source_sql(source_df)}), {struct_options_sql})"""
200213

201214
# ML evaluation TVFs
202-
def ml_evaluate(self, source_sql: Optional[str] = None) -> str:
215+
def ml_evaluate(self, source_df: Optional[bpd.DataFrame] = None) -> str:
203216
"""Encode ML.EVALUATE for BQML"""
217+
if source_df is None:
218+
source_sql = None
219+
else:
220+
# Note: don't need index as evaluate returns a new table
221+
source_sql, _, _ = source_df._to_sql_query(include_index=False)
222+
204223
if source_sql is None:
205224
return f"""SELECT * FROM ML.EVALUATE(MODEL `{self._model_name}`)"""
206225
else:
@@ -222,7 +241,7 @@ def ml_principal_component_info(self) -> str:
222241
)
223242

224243
# ML transform TVF, which requires a transform_only type model
225-
def ml_transform(self, source_sql: str) -> str:
244+
def ml_transform(self, source_df: bpd.DataFrame) -> str:
226245
"""Encode ML.TRANSFORM for BQML"""
227246
return f"""SELECT * FROM ML.TRANSFORM(MODEL `{self._model_name}`,
228-
({source_sql}))"""
247+
({self._source_sql(source_df)}))"""

tests/unit/ml/test_sql.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def model_manipulation_sql_generator() -> ml_sql.ModelManipulationSqlGenerator:
4141
def mock_df():
4242
mock_df = mock.create_autospec(spec=bpd.DataFrame)
4343
mock_df.sql = "input_X_y_sql"
44+
mock_df._to_sql_query.return_value = "input_X_sql", None, None
4445

4546
return mock_df
4647

@@ -117,7 +118,7 @@ def test_create_model_produces_correct_sql(
117118
mock_df: bpd.DataFrame,
118119
):
119120
sql = model_creation_sql_generator.create_model(
120-
source=mock_df,
121+
source_df=mock_df,
121122
options={"option_key1": "option_value1", "option_key2": 2},
122123
)
123124
assert (
@@ -135,7 +136,7 @@ def test_create_model_transform_produces_correct_sql(
135136
mock_df: bpd.DataFrame,
136137
):
137138
sql = model_creation_sql_generator.create_model(
138-
source=mock_df,
139+
source_df=mock_df,
139140
options={"option_key1": "option_value1", "option_key2": 2},
140141
transforms=[
141142
"ML.STANDARD_SCALER(col_a) OVER(col_a) AS scaled_col_a",
@@ -191,38 +192,38 @@ def test_alter_model_correct_sql(
191192
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
192193
):
193194
sql = model_manipulation_sql_generator.alter_model(
194-
options_sql="my_options_sql",
195+
options={"option_key1": "option_value1", "option_key2": 2},
195196
)
196197
assert (
197198
sql
198199
== """ALTER MODEL `my_project_id.my_dataset_id.my_model_id`
199-
SET my_options_sql"""
200+
SET OPTIONS(
201+
option_key1="option_value1",
202+
option_key2=2)"""
200203
)
201204

202205

203206
def test_ml_predict_produces_correct_sql(
204207
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
208+
mock_df: bpd.DataFrame,
205209
):
206-
sql = model_manipulation_sql_generator.ml_predict(
207-
source_sql="SELECT * FROM my_table"
208-
)
210+
sql = model_manipulation_sql_generator.ml_predict(source_df=mock_df)
209211
assert (
210212
sql
211213
== """SELECT * FROM ML.PREDICT(MODEL `my_project_id.my_dataset_id.my_model_id`,
212-
(SELECT * FROM my_table))"""
214+
(input_X_sql))"""
213215
)
214216

215217

216218
def test_ml_evaluate_produces_correct_sql(
217219
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
220+
mock_df: bpd.DataFrame,
218221
):
219-
sql = model_manipulation_sql_generator.ml_evaluate(
220-
source_sql="SELECT * FROM my_table"
221-
)
222+
sql = model_manipulation_sql_generator.ml_evaluate(source_df=mock_df)
222223
assert (
223224
sql
224225
== """SELECT * FROM ML.EVALUATE(MODEL `my_project_id.my_dataset_id.my_model_id`,
225-
(SELECT * FROM my_table))"""
226+
(input_X_sql))"""
226227
)
227228

228229

@@ -248,15 +249,35 @@ def test_ml_centroids_produces_correct_sql(
248249

249250
def test_ml_generate_text_produces_correct_sql(
250251
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
252+
mock_df: bpd.DataFrame,
251253
):
252254
sql = model_manipulation_sql_generator.ml_generate_text(
253-
source_sql="SELECT * FROM my_table",
254-
struct_options="STRUCT(value AS item)",
255+
source_df=mock_df,
256+
struct_options={"option_key1": 1, "option_key2": 2.2},
255257
)
256258
assert (
257259
sql
258260
== """SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project_id.my_dataset_id.my_model_id`,
259-
(SELECT * FROM my_table), STRUCT(value AS item))"""
261+
(input_X_sql), STRUCT(
262+
1 AS option_key1,
263+
2.2 AS option_key2))"""
264+
)
265+
266+
267+
def test_ml_generate_text_embedding_produces_correct_sql(
268+
model_manipulation_sql_generator: ml_sql.ModelManipulationSqlGenerator,
269+
mock_df: bpd.DataFrame,
270+
):
271+
sql = model_manipulation_sql_generator.ml_generate_text_embedding(
272+
source_df=mock_df,
273+
struct_options={"option_key1": 1, "option_key2": 2.2},
274+
)
275+
assert (
276+
sql
277+
== """SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(MODEL `my_project_id.my_dataset_id.my_model_id`,
278+
(input_X_sql), STRUCT(
279+
1 AS option_key1,
280+
2.2 AS option_key2))"""
260281
)
261282

262283

0 commit comments

Comments
 (0)