1414
1515from __future__ import annotations
1616
17- from typing import Mapping , Optional , Union
17+ from typing import cast , Mapping , Optional , Union
1818
1919import bigframes_vendored .constants
2020import google .cloud .bigquery
2828
2929
3030# Helper to convert DataFrame to SQL string
31- def _to_sql (df_or_sql : Union [dataframe .DataFrame , str ]) -> str :
31+ def _to_sql (df_or_sql : Union [pd .DataFrame , dataframe .DataFrame , str ]) -> str :
32+ import bigframes .pandas as bpd
33+
3234 if isinstance (df_or_sql , str ):
3335 return df_or_sql
34- # It's a DataFrame
35- sql , _ , _ = df_or_sql ._to_sql_query (include_index = False )
36+
37+ if isinstance (df_or_sql , pd .DataFrame ):
38+ bf_df = bpd .read_pandas (df_or_sql )
39+ else :
40+ bf_df = cast (dataframe .DataFrame , df_or_sql )
41+
42+ sql , _ , _ = bf_df ._to_sql_query (include_index = False )
3643 return sql
3744
3845
3946def _get_model_name_and_session (
40- model : Union [bigframes .ml .base .BaseEstimator , str ],
47+ model : Union [bigframes .ml .base .BaseEstimator , str , pd . Series ],
4148 # Other dataframe arguments to extract session from
42- * dataframes : Optional [Union [dataframe .DataFrame , str ]],
43- ) -> tuple [str , bigframes .session .Session ]:
44- import bigframes .pandas as bpd
45-
46- if isinstance (model , str ):
49+ * dataframes : Optional [Union [pd .DataFrame , dataframe .DataFrame , str ]],
50+ ) -> tuple [str , Optional [bigframes .session .Session ]]:
51+ if isinstance (model , pd .Series ):
52+ model_ref = model ["modelReference" ]
53+ model_name = f"{ model_ref ['projectId' ]} .{ model_ref ['datasetId' ]} .{ model_ref ['modelId' ]} " # type: ignore
54+ elif isinstance (model , str ):
4755 model_name = model
48- session = None
49- for df in dataframes :
50- if isinstance (df , dataframe .DataFrame ):
51- session = df ._session
52- break
53- if session is None :
54- session = bpd .get_global_session ()
55- return model_name , session
5656 else :
5757 if model ._bqml_model is None :
5858 raise ValueError ("Model must be fitted to be used in ML operations." )
5959 return model ._bqml_model .model_name , model ._bqml_model .session
6060
61+ session = None
62+ for df in dataframes :
63+ if isinstance (df , dataframe .DataFrame ):
64+ session = df ._session
65+ break
66+
67+ return model_name , session
68+
6169
6270def _get_model_metadata (
6371 * ,
@@ -82,8 +90,8 @@ def create_model(
8290 output_schema : Optional [Mapping [str , str ]] = None ,
8391 connection_name : Optional [str ] = None ,
8492 options : Optional [Mapping [str , Union [str , int , float , bool , list ]]] = None ,
85- training_data : Optional [Union [dataframe .DataFrame , str ]] = None ,
86- custom_holiday : Optional [Union [dataframe .DataFrame , str ]] = None ,
93+ training_data : Optional [Union [pd . DataFrame , dataframe .DataFrame , str ]] = None ,
94+ custom_holiday : Optional [Union [pd . DataFrame , dataframe .DataFrame , str ]] = None ,
8795 session : Optional [bigframes .session .Session ] = None ,
8896) -> pd .Series :
8997 """
@@ -169,8 +177,8 @@ def create_model(
169177
170178@log_adapter .method_logger (custom_base_name = "bigquery_ml" )
171179def evaluate (
172- model : Union [bigframes .ml .base .BaseEstimator , str ],
173- input_ : Optional [Union [dataframe .DataFrame , str ]] = None ,
180+ model : Union [bigframes .ml .base .BaseEstimator , str , pd . Series ],
181+ input_ : Optional [Union [pd . DataFrame , dataframe .DataFrame , str ]] = None ,
174182 * ,
175183 perform_aggregation : Optional [bool ] = None ,
176184 horizon : Optional [int ] = None ,
@@ -210,6 +218,8 @@ def evaluate(
210218 bigframes.pandas.DataFrame:
211219 The evaluation results.
212220 """
221+ import bigframes .pandas as bpd
222+
213223 model_name , session = _get_model_name_and_session (model , input_ )
214224 table_sql = _to_sql (input_ ) if input_ is not None else None
215225
@@ -221,13 +231,16 @@ def evaluate(
221231 confidence_level = confidence_level ,
222232 )
223233
224- return session .read_gbq (sql )
234+ if session is None :
235+ return bpd .read_gbq_query (sql )
236+ else :
237+ return session .read_gbq_query (sql )
225238
226239
227240@log_adapter .method_logger (custom_base_name = "bigquery_ml" )
228241def predict (
229- model : Union [bigframes .ml .base .BaseEstimator , str ],
230- input_ : Union [dataframe .DataFrame , str ],
242+ model : Union [bigframes .ml .base .BaseEstimator , str , pd . Series ],
243+ input_ : Union [pd . DataFrame , dataframe .DataFrame , str ],
231244 * ,
232245 threshold : Optional [float ] = None ,
233246 keep_original_columns : Optional [bool ] = None ,
@@ -259,6 +272,8 @@ def predict(
259272 bigframes.pandas.DataFrame:
260273 The prediction results.
261274 """
275+ import bigframes .pandas as bpd
276+
262277 model_name , session = _get_model_name_and_session (model , input_ )
263278 table_sql = _to_sql (input_ )
264279
@@ -270,13 +285,16 @@ def predict(
270285 trial_id = trial_id ,
271286 )
272287
273- return session .read_gbq (sql )
288+ if session is None :
289+ return bpd .read_gbq_query (sql )
290+ else :
291+ return session .read_gbq_query (sql )
274292
275293
276294@log_adapter .method_logger (custom_base_name = "bigquery_ml" )
277295def explain_predict (
278- model : Union [bigframes .ml .base .BaseEstimator , str ],
279- input_ : Union [dataframe .DataFrame , str ],
296+ model : Union [bigframes .ml .base .BaseEstimator , str , pd . Series ],
297+ input_ : Union [pd . DataFrame , dataframe .DataFrame , str ],
280298 * ,
281299 top_k_features : Optional [int ] = None ,
282300 threshold : Optional [float ] = None ,
@@ -313,6 +331,8 @@ def explain_predict(
313331 bigframes.pandas.DataFrame:
314332 The prediction results with explanations.
315333 """
334+ import bigframes .pandas as bpd
335+
316336 model_name , session = _get_model_name_and_session (model , input_ )
317337 table_sql = _to_sql (input_ )
318338
@@ -325,12 +345,15 @@ def explain_predict(
325345 approx_feature_contrib = approx_feature_contrib ,
326346 )
327347
328- return session .read_gbq (sql )
348+ if session is None :
349+ return bpd .read_gbq_query (sql )
350+ else :
351+ return session .read_gbq_query (sql )
329352
330353
331354@log_adapter .method_logger (custom_base_name = "bigquery_ml" )
332355def global_explain (
333- model : Union [bigframes .ml .base .BaseEstimator , str ],
356+ model : Union [bigframes .ml .base .BaseEstimator , str , pd . Series ],
334357 * ,
335358 class_level_explain : Optional [bool ] = None ,
336359) -> dataframe .DataFrame :
@@ -351,10 +374,15 @@ def global_explain(
351374 bigframes.pandas.DataFrame:
352375 The global explanation results.
353376 """
377+ import bigframes .pandas as bpd
378+
354379 model_name , session = _get_model_name_and_session (model )
355380 sql = bigframes .core .sql .ml .global_explain (
356381 model_name = model_name ,
357382 class_level_explain = class_level_explain ,
358383 )
359384
360- return session .read_gbq (sql )
385+ if session is None :
386+ return bpd .read_gbq_query (sql )
387+ else :
388+ return session .read_gbq_query (sql )
0 commit comments