Skip to content

Commit 14b262b

Browse files
authored
feat: add ml.preprocessing.MaxAbsScaler (#56)
1 parent 416d7cb commit 14b262b

File tree

12 files changed

+370
-71
lines changed

12 files changed

+370
-71
lines changed

bigframes/clients.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import logging
2020
import time
21-
from typing import Optional
21+
from typing import cast, Optional
2222

2323
import google.api_core.exceptions
2424
from google.cloud import bigquery_connection_v1, resourcemanager_v3
@@ -80,6 +80,7 @@ def create_bq_connection(
8080
logger.info(
8181
f"Created BQ connection {connection_name} with service account id: {service_account_id}"
8282
)
83+
service_account_id = cast(str, service_account_id)
8384
# Ensure IAM role on the BQ connection
8485
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#grant_permission_on_function
8586
self._ensure_iam_binding(project_id, service_account_id, iam_role)

bigframes/ml/compose.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
CompilablePreprocessorType = Union[
3030
preprocessing.OneHotEncoder,
3131
preprocessing.StandardScaler,
32+
preprocessing.MaxAbsScaler,
3233
preprocessing.LabelEncoder,
3334
]
3435

bigframes/ml/pipeline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
5050
compose.ColumnTransformer,
5151
preprocessing.StandardScaler,
5252
preprocessing.OneHotEncoder,
53+
preprocessing.MaxAbsScaler,
5354
preprocessing.LabelEncoder,
5455
),
5556
):
@@ -147,6 +148,7 @@ def _extract_as_column_transformer(
147148
Union[
148149
preprocessing.OneHotEncoder,
149150
preprocessing.StandardScaler,
151+
preprocessing.MaxAbsScaler,
150152
preprocessing.LabelEncoder,
151153
],
152154
Union[str, List[str]],
@@ -172,6 +174,13 @@ def _extract_as_column_transformer(
172174
*preprocessing.OneHotEncoder._parse_from_sql(transform_sql),
173175
)
174176
)
177+
elif transform_sql.startswith("ML.MAX_ABS_SCALER"):
178+
transformers.append(
179+
(
180+
"max_abs_encoder",
181+
*preprocessing.MaxAbsScaler._parse_from_sql(transform_sql),
182+
)
183+
)
175184
elif transform_sql.startswith("ML.LABEL_ENCODER"):
176185
transformers.append(
177186
(
@@ -193,6 +202,7 @@ def _merge_column_transformer(
193202
compose.ColumnTransformer,
194203
preprocessing.StandardScaler,
195204
preprocessing.OneHotEncoder,
205+
preprocessing.MaxAbsScaler,
196206
preprocessing.LabelEncoder,
197207
]:
198208
"""Try to merge the column transformer to a simple transformer."""

bigframes/ml/preprocessing.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
5454
Returns: a list of tuples of (sql_expression, output_name)"""
5555
return [
5656
(
57-
self._base_sql_generator.ml_standard_scaler(column, f"scaled_{column}"),
58-
f"scaled_{column}",
57+
self._base_sql_generator.ml_standard_scaler(
58+
column, f"standard_scaled_{column}"
59+
),
60+
f"standard_scaled_{column}",
5961
)
6062
for column in columns
6163
]
@@ -105,6 +107,86 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
105107
)
106108

107109

110+
class MaxAbsScaler(
    base.Transformer,
    third_party.bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler,
):
    # Reuse the vendored sklearn docstring so help() matches the sklearn contract.
    __doc__ = (
        third_party.bigframes_vendored.sklearn.preprocessing._data.MaxAbsScaler.__doc__
    )

    def __init__(self):
        # None until fit() creates the BQML transform-only model.
        self._bqml_model: Optional[core.BqmlModel] = None
        self._bqml_model_factory = globals.bqml_model_factory()
        self._base_sql_generator = globals.base_sql_generator()

    # TODO(garrettwu): implement __hash__
    def __eq__(self, other: Any) -> bool:
        return type(other) is MaxAbsScaler and self._bqml_model == other._bqml_model

    def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
        """Compile this transformer to a list of SQL expressions that can be included in
        a BQML TRANSFORM clause

        Args:
            columns: a list of column names to transform

        Returns: a list of tuples of (sql_expression, output_name)"""
        return [
            (
                self._base_sql_generator.ml_max_abs_scaler(
                    column, f"max_abs_scaled_{column}"
                ),
                f"max_abs_scaled_{column}",
            )
            for column in columns
        ]

    @classmethod
    def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]:
        """Parse SQL to tuple(MaxAbsScaler, column_label).

        Args:
            sql: SQL string of format "ML.MAX_ABS_SCALER({col_label}) OVER()"

        Returns:
            tuple(MaxAbsScaler, column_label)"""
        # Extract the column label between the first "(" and the first ")".
        col_label = sql[sql.find("(") + 1 : sql.find(")")]
        return cls(), col_label

    def fit(
        self,
        X: Union[bpd.DataFrame, bpd.Series],
        y=None,  # ignored
    ) -> MaxAbsScaler:
        """Fit the scaler by creating a BQML transform-only model over X.

        Args:
            X: DataFrame or Series of numeric columns to scale.
            y: Ignored; present for sklearn API compatibility.

        Returns:
            self, fitted."""
        (X,) = utils.convert_to_dataframe(X)

        compiled_transforms = self._compile_to_sql(X.columns.tolist())
        transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]

        self._bqml_model = self._bqml_model_factory.create_model(
            X,
            options={"model_type": "transform_only"},
            transforms=transform_sqls,
        )

        # The schema of TRANSFORM output is not available in the model API, so save it during fitting
        self._output_names = [name for _, name in compiled_transforms]
        return self

    def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        """Scale X using the fitted BQML model.

        Args:
            X: DataFrame or Series to transform.

        Returns:
            DataFrame with one "max_abs_scaled_{column}" column per input column.

        Raises:
            RuntimeError: if fit() has not been called."""
        if not self._bqml_model:
            raise RuntimeError("Must be fitted before transform")

        (X,) = utils.convert_to_dataframe(X)

        df = self._bqml_model.transform(X)
        return typing.cast(
            bpd.DataFrame,
            df[self._output_names],
        )
188+
189+
108190
class OneHotEncoder(
109191
base.Transformer,
110192
third_party.bigframes_vendored.sklearn.preprocessing._encoder.OneHotEncoder,

bigframes/ml/sql.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def ml_standard_scaler(self, numeric_expr_sql: str, name: str) -> str:
7676
"""Encode ML.STANDARD_SCALER for BQML"""
7777
return f"""ML.STANDARD_SCALER({numeric_expr_sql}) OVER() AS {name}"""
7878

79+
def ml_max_abs_scaler(self, numeric_expr_sql: str, name: str) -> str:
80+
"""Encode ML.MAX_ABS_SCALER for BQML"""
81+
return f"""ML.MAX_ABS_SCALER({numeric_expr_sql}) OVER() AS {name}"""
82+
7983
def ml_one_hot_encoder(
8084
self,
8185
numeric_expr_sql: str,

tests/system/large/ml/test_compose.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,20 @@ def test_columntransformer_standalone_fit_and_transform(
5656
[{"index": 1, "value": 1.0}],
5757
[{"index": 2, "value": 1.0}],
5858
],
59-
"scaled_culmen_length_mm": [
59+
"standard_scaled_culmen_length_mm": [
6060
-0.811119671289163,
6161
-0.9945520581113803,
6262
-1.104611490204711,
6363
],
64-
"scaled_flipper_length_mm": [-0.350044, -1.418336, -0.9198],
64+
"standard_scaled_flipper_length_mm": [-0.350044, -1.418336, -0.9198],
6565
},
6666
index=pandas.Index([1633, 1672, 1690], dtype="Int64", name="tag_number"),
6767
)
68-
expected.scaled_culmen_length_mm = expected.scaled_culmen_length_mm.astype(
69-
"Float64"
68+
expected.standard_scaled_culmen_length_mm = (
69+
expected.standard_scaled_culmen_length_mm.astype("Float64")
7070
)
71-
expected.scaled_flipper_length_mm = expected.scaled_flipper_length_mm.astype(
72-
"Float64"
71+
expected.standard_scaled_flipper_length_mm = (
72+
expected.standard_scaled_flipper_length_mm.astype("Float64")
7373
)
7474

7575
pandas.testing.assert_frame_equal(result, expected, rtol=1e-3)
@@ -107,20 +107,20 @@ def test_columntransformer_standalone_fit_transform(new_penguins_df):
107107
[{"index": 1, "value": 1.0}],
108108
[{"index": 2, "value": 1.0}],
109109
],
110-
"scaled_culmen_length_mm": [
110+
"standard_scaled_culmen_length_mm": [
111111
1.313249,
112112
-0.20198,
113113
-1.111118,
114114
],
115-
"scaled_flipper_length_mm": [1.251098, -1.196588, -0.054338],
115+
"standard_scaled_flipper_length_mm": [1.251098, -1.196588, -0.054338],
116116
},
117117
index=pandas.Index([1633, 1672, 1690], dtype="Int64", name="tag_number"),
118118
)
119-
expected.scaled_culmen_length_mm = expected.scaled_culmen_length_mm.astype(
120-
"Float64"
119+
expected.standard_scaled_culmen_length_mm = (
120+
expected.standard_scaled_culmen_length_mm.astype("Float64")
121121
)
122-
expected.scaled_flipper_length_mm = expected.scaled_flipper_length_mm.astype(
123-
"Float64"
122+
expected.standard_scaled_flipper_length_mm = (
123+
expected.standard_scaled_flipper_length_mm.astype("Float64")
124124
)
125125

126126
pandas.testing.assert_frame_equal(result, expected, rtol=1e-3)

tests/system/large/ml/test_pipeline.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -566,10 +566,15 @@ def test_pipeline_columntransformer_fit_predict(session, penguins_df_default_ind
566566
"species",
567567
),
568568
(
569-
"scale",
569+
"standard_scale",
570570
preprocessing.StandardScaler(),
571571
["culmen_length_mm", "flipper_length_mm"],
572572
),
573+
(
574+
"max_abs_scale",
575+
preprocessing.MaxAbsScaler(),
576+
["culmen_length_mm", "flipper_length_mm"],
577+
),
573578
(
574579
"label",
575580
preprocessing.LabelEncoder(),
@@ -637,6 +642,11 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
637642
preprocessing.StandardScaler(),
638643
["culmen_length_mm", "flipper_length_mm"],
639644
),
645+
(
646+
"max_abs_scale",
647+
preprocessing.MaxAbsScaler(),
648+
["culmen_length_mm", "flipper_length_mm"],
649+
),
640650
(
641651
"label",
642652
preprocessing.LabelEncoder(),
@@ -660,30 +670,26 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
660670

661671
assert isinstance(pl_loaded._transform, compose.ColumnTransformer)
662672
transformers = pl_loaded._transform.transformers_
663-
assert len(transformers) == 4
664-
665-
assert transformers[0][0] == "ont_hot_encoder"
666-
assert isinstance(transformers[0][1], preprocessing.OneHotEncoder)
667-
one_hot_encoder = transformers[0][1]
668-
assert one_hot_encoder.drop == "most_frequent"
669-
assert one_hot_encoder.min_frequency == 5
670-
assert one_hot_encoder.max_categories == 100
671-
assert transformers[0][2] == "species"
672-
673-
assert transformers[1][0] == "label_encoder"
674-
assert isinstance(transformers[1][1], preprocessing.LabelEncoder)
675-
one_hot_encoder = transformers[1][1]
676-
assert one_hot_encoder.min_frequency == 0
677-
assert one_hot_encoder.max_categories == 1000001
678-
assert transformers[1][2] == "species"
679-
680-
assert transformers[2][0] == "standard_scaler"
681-
assert isinstance(transformers[2][1], preprocessing.StandardScaler)
682-
assert transformers[2][2] == "culmen_length_mm"
673+
expected = [
674+
(
675+
"ont_hot_encoder",
676+
preprocessing.OneHotEncoder(
677+
drop="most_frequent", max_categories=100, min_frequency=5
678+
),
679+
"species",
680+
),
681+
(
682+
"label_encoder",
683+
preprocessing.LabelEncoder(max_categories=1000001, min_frequency=0),
684+
"species",
685+
),
686+
("standard_scaler", preprocessing.StandardScaler(), "culmen_length_mm"),
687+
("max_abs_encoder", preprocessing.MaxAbsScaler(), "culmen_length_mm"),
688+
("standard_scaler", preprocessing.StandardScaler(), "flipper_length_mm"),
689+
("max_abs_encoder", preprocessing.MaxAbsScaler(), "flipper_length_mm"),
690+
]
683691

684-
assert transformers[3][0] == "standard_scaler"
685-
assert isinstance(transformers[2][1], preprocessing.StandardScaler)
686-
assert transformers[3][2] == "flipper_length_mm"
692+
assert transformers == expected
687693

688694
assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
689695
assert pl_loaded._estimator.fit_intercept is False
@@ -717,6 +723,34 @@ def test_pipeline_standard_scaler_to_gbq(penguins_df_default_index, dataset_id):
717723
assert pl_loaded._estimator.fit_intercept is False
718724

719725

726+
def test_pipeline_max_abs_scaler_to_gbq(penguins_df_default_index, dataset_id):
    """Round-trip a MaxAbsScaler + LinearRegression pipeline through to_gbq
    and check the loaded pipeline's components."""
    pl = pipeline.Pipeline(
        [
            ("transform", preprocessing.MaxAbsScaler()),
            ("estimator", linear_model.LinearRegression(fit_intercept=False)),
        ]
    )

    df = penguins_df_default_index.dropna()
    X_train = df[
        [
            "culmen_length_mm",
            "culmen_depth_mm",
            "flipper_length_mm",
        ]
    ]
    y_train = df[["body_mass_g"]]
    pl.fit(X_train, y_train)

    # Save under a name distinct from the StandardScaler test's table: both
    # use replace=True, so sharing a name would let the tests clobber each
    # other's persisted model.
    pl_loaded = pl.to_gbq(
        f"{dataset_id}.test_penguins_pipeline_max_abs_scaler", replace=True
    )
    assert isinstance(pl_loaded._transform, preprocessing.MaxAbsScaler)

    assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
    assert pl_loaded._estimator.fit_intercept is False
752+
753+
720754
def test_pipeline_one_hot_encoder_to_gbq(penguins_df_default_index, dataset_id):
721755
pl = pipeline.Pipeline(
722756
[

0 commit comments

Comments
 (0)