Skip to content

Commit 68e770b

Browse files
committed
return pd.Series from create_model
1 parent 74d4fcc commit 68e770b

File tree

6 files changed

+356
-117
lines changed

6 files changed

+356
-117
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
from typing import Mapping, Optional, Union
1818

19+
import bigframes_vendored.constants
20+
import google.cloud.bigquery
21+
import pandas as pd
22+
1923
import bigframes.core.log_adapter as log_adapter
2024
import bigframes.core.sql.ml
2125
import bigframes.dataframe as dataframe
@@ -55,6 +59,16 @@ def _get_model_name_and_session(
5559
return model._bqml_model.model_name, model._bqml_model.session
5660

5761

62+
def _get_model_metadata(
63+
*,
64+
bqclient: google.cloud.bigquery.Client,
65+
model_name: str,
66+
) -> pd.Series:
67+
model_metadata = bqclient.get_model(model_name)
68+
model_dict = model_metadata.to_api_repr()
69+
return pd.Series(model_dict)
70+
71+
5872
@log_adapter.method_logger(custom_base_name="bigquery_ml")
5973
def create_model(
6074
model_name: str,
@@ -71,7 +85,7 @@ def create_model(
7185
training_data: Optional[Union[dataframe.DataFrame, str]] = None,
7286
custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
7387
session: Optional[bigframes.session.Session] = None,
74-
) -> bigframes.ml.base.BaseEstimator:
88+
) -> pd.Series:
7589
"""
7690
Creates a BigQuery ML model.
7791
@@ -105,8 +119,12 @@ def create_model(
105119
The session to use. If not provided, the default session is used.
106120
107121
Returns:
108-
bigframes.ml.base.BaseEstimator:
109-
The created BigQuery ML model.
122+
pandas.Series:
123+
A Series with object dtype containing the model metadata. Reference
124+
the `BigQuery Model REST API reference
125+
<https://docs.cloud.google.com/bigquery/docs/reference/rest/v2/models#Model>`_
126+
for available fields.
127+
110128
"""
111129
import bigframes.pandas as bpd
112130

@@ -138,12 +156,15 @@ def create_model(
138156
)
139157

140158
if session is None:
159+
bpd.read_gbq_query(sql)
141160
session = bpd.get_global_session()
161+
assert (
162+
session is not None
163+
), f"Missing connection to BigQuery. Please report how you encountered this error at {bigframes_vendored.constants.FEEDBACK_LINK}."
164+
else:
165+
session.read_gbq_query(sql)
142166

143-
# Use _start_query_ml_ddl which is designed for this
144-
session._start_query_ml_ddl(sql)
145-
146-
return session.read_gbq_model(model_name)
167+
return _get_model_metadata(bqclient=session.bqclient, model_name=model_name)
147168

148169

149170
@log_adapter.method_logger(custom_base_name="bigquery_ml")

bigframes/ml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
<https://docs.cloud.google.com/bigquery/docs/linear-regression-tutorial>`_ for a
3838
detailed example.
3939
40-
See all, the references for ``bigframes.ml`` sub-modules:
40+
See also the references for ``bigframes.ml`` sub-modules:
4141
4242
* :mod:`bigframes.ml.cluster`
4343
* :mod:`bigframes.ml.compose`

docs/reference/index.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ scikit-learn.
2828
:toctree: api
2929

3030
bigframes.ml
31-
bigframes.ml.base
3231
bigframes.ml.cluster
3332
bigframes.ml.compose
3433
bigframes.ml.decomposition

notebooks/.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
.ipynb_checkpoints/
2+
*.bq_exec_time_seconds
3+
*.bytesprocessed
4+
*.query_char_count
5+
*.slotmillis

0 commit comments

Comments
 (0)