Skip to content

Commit 33274c2

Browse files
authored
refactor: ml.sql to Object (#44)
* refactor: ml.sql to Object Change-Id: Ibf795b81619778eaf28572fccd95a09b65f8ad58
1 parent 5e199ec commit 33274c2

File tree

16 files changed

+559
-410
lines changed

16 files changed

+559
-410
lines changed

bigframes/ml/cluster.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from google.cloud import bigquery
2323

2424
import bigframes
25-
from bigframes.ml import base, core, utils
25+
from bigframes.ml import base, core, globals, utils
2626
import bigframes.pandas as bpd
2727
import third_party.bigframes_vendored.sklearn.cluster._kmeans
2828

@@ -37,6 +37,7 @@ class KMeans(
3737
def __init__(self, n_clusters: int = 8):
3838
self.n_clusters = n_clusters
3939
self._bqml_model: Optional[core.BqmlModel] = None
40+
self._bqml_model_factory = globals.bqml_model_factory()
4041

4142
@classmethod
4243
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> KMeans:
@@ -66,7 +67,7 @@ def _fit(
6667
) -> KMeans:
6768
(X,) = utils.convert_to_dataframe(X)
6869

69-
self._bqml_model = core.create_bqml_model(
70+
self._bqml_model = self._bqml_model_factory.create_model(
7071
X_train=X,
7172
transforms=transforms,
7273
options=self._bqml_options,

bigframes/ml/compose.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import List, Optional, Tuple, Union
2323

2424
from bigframes import constants
25-
from bigframes.ml import base, core, preprocessing, utils
25+
from bigframes.ml import base, core, globals, preprocessing, utils
2626
import bigframes.pandas as bpd
2727
import third_party.bigframes_vendored.sklearn.compose._column_transformer
2828

@@ -53,6 +53,7 @@ def __init__(
5353
# TODO: if any(transformers) has fitted raise warning
5454
self.transformers = transformers
5555
self._bqml_model: Optional[core.BqmlModel] = None
56+
self._bqml_model_factory = globals.bqml_model_factory()
5657
# call self.transformers_ to check chained transformers
5758
self.transformers_
5859

@@ -114,7 +115,7 @@ def fit(
114115
compiled_transforms = self._compile_to_sql(X.columns.tolist())
115116
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
116117

117-
self._bqml_model = core.create_bqml_model(
118+
self._bqml_model = self._bqml_model_factory.create_model(
118119
X,
119120
options={"model_type": "transform_only"},
120121
transforms=transform_sqls,

0 commit comments

Comments
 (0)