Skip to content

Commit 82d1aec

Browse files
committed
support pandas inputs
1 parent 68e770b commit 82d1aec

File tree

2 files changed

+419
-117
lines changed

2 files changed

+419
-117
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Mapping, Optional, Union
17+
from typing import cast, Mapping, Optional, Union
1818

1919
import bigframes_vendored.constants
2020
import google.cloud.bigquery
@@ -28,36 +28,44 @@
2828

2929

3030
# Helper to convert DataFrame to SQL string
31-
def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
31+
def _to_sql(df_or_sql: Union[pd.DataFrame, dataframe.DataFrame, str]) -> str:
32+
import bigframes.pandas as bpd
33+
3234
if isinstance(df_or_sql, str):
3335
return df_or_sql
34-
# It's a DataFrame
35-
sql, _, _ = df_or_sql._to_sql_query(include_index=False)
36+
37+
if isinstance(df_or_sql, pd.DataFrame):
38+
bf_df = bpd.read_pandas(df_or_sql)
39+
else:
40+
bf_df = cast(dataframe.DataFrame, df_or_sql)
41+
42+
sql, _, _ = bf_df._to_sql_query(include_index=False)
3643
return sql
3744

3845

3946
def _get_model_name_and_session(
40-
model: Union[bigframes.ml.base.BaseEstimator, str],
47+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
4148
# Other dataframe arguments to extract session from
42-
*dataframes: Optional[Union[dataframe.DataFrame, str]],
43-
) -> tuple[str, bigframes.session.Session]:
44-
import bigframes.pandas as bpd
45-
46-
if isinstance(model, str):
49+
*dataframes: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]],
50+
) -> tuple[str, Optional[bigframes.session.Session]]:
51+
if isinstance(model, pd.Series):
52+
model_ref = model["modelReference"]
53+
model_name = f"{model_ref['projectId']}.{model_ref['datasetId']}.{model_ref['modelId']}" # type: ignore
54+
elif isinstance(model, str):
4755
model_name = model
48-
session = None
49-
for df in dataframes:
50-
if isinstance(df, dataframe.DataFrame):
51-
session = df._session
52-
break
53-
if session is None:
54-
session = bpd.get_global_session()
55-
return model_name, session
5656
else:
5757
if model._bqml_model is None:
5858
raise ValueError("Model must be fitted to be used in ML operations.")
5959
return model._bqml_model.model_name, model._bqml_model.session
6060

61+
session = None
62+
for df in dataframes:
63+
if isinstance(df, dataframe.DataFrame):
64+
session = df._session
65+
break
66+
67+
return model_name, session
68+
6169

6270
def _get_model_metadata(
6371
*,
@@ -82,8 +90,8 @@ def create_model(
8290
output_schema: Optional[Mapping[str, str]] = None,
8391
connection_name: Optional[str] = None,
8492
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
85-
training_data: Optional[Union[dataframe.DataFrame, str]] = None,
86-
custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
93+
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
94+
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
8795
session: Optional[bigframes.session.Session] = None,
8896
) -> pd.Series:
8997
"""
@@ -169,8 +177,8 @@ def create_model(
169177

170178
@log_adapter.method_logger(custom_base_name="bigquery_ml")
171179
def evaluate(
172-
model: Union[bigframes.ml.base.BaseEstimator, str],
173-
input_: Optional[Union[dataframe.DataFrame, str]] = None,
180+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
181+
input_: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
174182
*,
175183
perform_aggregation: Optional[bool] = None,
176184
horizon: Optional[int] = None,
@@ -210,6 +218,8 @@ def evaluate(
210218
bigframes.pandas.DataFrame:
211219
The evaluation results.
212220
"""
221+
import bigframes.pandas as bpd
222+
213223
model_name, session = _get_model_name_and_session(model, input_)
214224
table_sql = _to_sql(input_) if input_ is not None else None
215225

@@ -221,13 +231,16 @@ def evaluate(
221231
confidence_level=confidence_level,
222232
)
223233

224-
return session.read_gbq(sql)
234+
if session is None:
235+
return bpd.read_gbq_query(sql)
236+
else:
237+
return session.read_gbq_query(sql)
225238

226239

227240
@log_adapter.method_logger(custom_base_name="bigquery_ml")
228241
def predict(
229-
model: Union[bigframes.ml.base.BaseEstimator, str],
230-
input_: Union[dataframe.DataFrame, str],
242+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
243+
input_: Union[pd.DataFrame, dataframe.DataFrame, str],
231244
*,
232245
threshold: Optional[float] = None,
233246
keep_original_columns: Optional[bool] = None,
@@ -259,6 +272,8 @@ def predict(
259272
bigframes.pandas.DataFrame:
260273
The prediction results.
261274
"""
275+
import bigframes.pandas as bpd
276+
262277
model_name, session = _get_model_name_and_session(model, input_)
263278
table_sql = _to_sql(input_)
264279

@@ -270,13 +285,16 @@ def predict(
270285
trial_id=trial_id,
271286
)
272287

273-
return session.read_gbq(sql)
288+
if session is None:
289+
return bpd.read_gbq_query(sql)
290+
else:
291+
return session.read_gbq_query(sql)
274292

275293

276294
@log_adapter.method_logger(custom_base_name="bigquery_ml")
277295
def explain_predict(
278-
model: Union[bigframes.ml.base.BaseEstimator, str],
279-
input_: Union[dataframe.DataFrame, str],
296+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
297+
input_: Union[pd.DataFrame, dataframe.DataFrame, str],
280298
*,
281299
top_k_features: Optional[int] = None,
282300
threshold: Optional[float] = None,
@@ -313,6 +331,8 @@ def explain_predict(
313331
bigframes.pandas.DataFrame:
314332
The prediction results with explanations.
315333
"""
334+
import bigframes.pandas as bpd
335+
316336
model_name, session = _get_model_name_and_session(model, input_)
317337
table_sql = _to_sql(input_)
318338

@@ -325,12 +345,15 @@ def explain_predict(
325345
approx_feature_contrib=approx_feature_contrib,
326346
)
327347

328-
return session.read_gbq(sql)
348+
if session is None:
349+
return bpd.read_gbq_query(sql)
350+
else:
351+
return session.read_gbq_query(sql)
329352

330353

331354
@log_adapter.method_logger(custom_base_name="bigquery_ml")
332355
def global_explain(
333-
model: Union[bigframes.ml.base.BaseEstimator, str],
356+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
334357
*,
335358
class_level_explain: Optional[bool] = None,
336359
) -> dataframe.DataFrame:
@@ -351,10 +374,15 @@ def global_explain(
351374
bigframes.pandas.DataFrame:
352375
The global explanation results.
353376
"""
377+
import bigframes.pandas as bpd
378+
354379
model_name, session = _get_model_name_and_session(model)
355380
sql = bigframes.core.sql.ml.global_explain(
356381
model_name=model_name,
357382
class_level_explain=class_level_explain,
358383
)
359384

360-
return session.read_gbq(sql)
385+
if session is None:
386+
return bpd.read_gbq_query(sql)
387+
else:
388+
return session.read_gbq_query(sql)

0 commit comments

Comments
 (0)