Skip to content

Commit 2510461

Browse files
authored
feat: add ml.preprocessing.LabelEncoder (#50)
1 parent 33274c2 commit 2510461

File tree

9 files changed

+415
-7
lines changed

9 files changed

+415
-7
lines changed

bigframes/ml/compose.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
CompilablePreprocessorType = Union[
3030
preprocessing.OneHotEncoder,
3131
preprocessing.StandardScaler,
32+
preprocessing.LabelEncoder,
3233
]
3334

3435

bigframes/ml/pipeline.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self, steps: List[Tuple[str, base.BaseEstimator]]):
5050
compose.ColumnTransformer,
5151
preprocessing.StandardScaler,
5252
preprocessing.OneHotEncoder,
53+
preprocessing.LabelEncoder,
5354
),
5455
):
5556
self._transform = transform
@@ -143,7 +144,11 @@ def _extract_as_column_transformer(
143144
transformers: List[
144145
Tuple[
145146
str,
146-
Union[preprocessing.OneHotEncoder, preprocessing.StandardScaler],
147+
Union[
148+
preprocessing.OneHotEncoder,
149+
preprocessing.StandardScaler,
150+
preprocessing.LabelEncoder,
151+
],
147152
Union[str, List[str]],
148153
]
149154
] = []
@@ -167,6 +172,13 @@ def _extract_as_column_transformer(
167172
*preprocessing.OneHotEncoder._parse_from_sql(transform_sql),
168173
)
169174
)
175+
elif transform_sql.startswith("ML.LABEL_ENCODER"):
176+
transformers.append(
177+
(
178+
"label_encoder",
179+
*preprocessing.LabelEncoder._parse_from_sql(transform_sql),
180+
)
181+
)
170182
else:
171183
raise NotImplementedError(
172184
f"Unsupported transformer type. {constants.FEEDBACK_LINK}"
@@ -181,6 +193,7 @@ def _merge_column_transformer(
181193
compose.ColumnTransformer,
182194
preprocessing.StandardScaler,
183195
preprocessing.OneHotEncoder,
196+
preprocessing.LabelEncoder,
184197
]:
185198
"""Try to merge the column transformer to a simple transformer."""
186199
transformers = column_transformer.transformers_

bigframes/ml/preprocessing.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import bigframes.pandas as bpd
2525
import third_party.bigframes_vendored.sklearn.preprocessing._data
2626
import third_party.bigframes_vendored.sklearn.preprocessing._encoder
27+
import third_party.bigframes_vendored.sklearn.preprocessing._label
2728

2829

2930
class StandardScaler(
@@ -229,3 +230,121 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
229230
bpd.DataFrame,
230231
df[self._output_names],
231232
)
233+
234+
235+
class LabelEncoder(
236+
base.Transformer,
237+
third_party.bigframes_vendored.sklearn.preprocessing._label.LabelEncoder,
238+
):
239+
# BQML max value https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-one-hot-encoder#syntax
240+
TOP_K_DEFAULT = 1000000
241+
FREQUENCY_THRESHOLD_DEFAULT = 0
242+
243+
__doc__ = (
244+
third_party.bigframes_vendored.sklearn.preprocessing._label.LabelEncoder.__doc__
245+
)
246+
247+
# All estimators must implement __init__ to document their parameters, even
248+
# if they don't have any
249+
def __init__(
250+
self,
251+
min_frequency: Optional[int] = None,
252+
max_categories: Optional[int] = None,
253+
):
254+
if max_categories is not None and max_categories < 2:
255+
raise ValueError(
256+
f"max_categories has to be larger than or equal to 2, input is {max_categories}."
257+
)
258+
self.min_frequency = min_frequency
259+
self.max_categories = max_categories
260+
self._bqml_model: Optional[core.BqmlModel] = None
261+
self._bqml_model_factory = globals.bqml_model_factory()
262+
self._base_sql_generator = globals.base_sql_generator()
263+
264+
# TODO(garrettwu): implement __hash__
265+
def __eq__(self, other: Any) -> bool:
266+
return (
267+
type(other) is LabelEncoder
268+
and self._bqml_model == other._bqml_model
269+
and self.min_frequency == other.min_frequency
270+
and self.max_categories == other.max_categories
271+
)
272+
273+
def _compile_to_sql(self, columns: List[str]) -> List[Tuple[str, str]]:
274+
"""Compile this transformer to a list of SQL expressions that can be included in
275+
a BQML TRANSFORM clause
276+
277+
Args:
278+
columns:
279+
a list of column names to transform
280+
281+
Returns: a list of tuples of (sql_expression, output_name)"""
282+
283+
# minus one here since BQML's implementation always includes index 0, and top_k is on top of that.
284+
top_k = (
285+
(self.max_categories - 1)
286+
if self.max_categories is not None
287+
else LabelEncoder.TOP_K_DEFAULT
288+
)
289+
frequency_threshold = (
290+
self.min_frequency
291+
if self.min_frequency is not None
292+
else LabelEncoder.FREQUENCY_THRESHOLD_DEFAULT
293+
)
294+
return [
295+
(
296+
self._base_sql_generator.ml_label_encoder(
297+
column, top_k, frequency_threshold, f"labelencoded_{column}"
298+
),
299+
f"labelencoded_{column}",
300+
)
301+
for column in columns
302+
]
303+
304+
@classmethod
305+
def _parse_from_sql(cls, sql: str) -> tuple[LabelEncoder, str]:
306+
"""Parse SQL to tuple(LabelEncoder, column_label).
307+
308+
Args:
309+
sql: SQL string of format "ML.LABEL_ENCODER({col_label}, {top_k}, {frequency_threshold}) OVER() "
310+
311+
Returns:
312+
tuple(LabelEncoder, column_label)"""
313+
s = sql[sql.find("(") + 1 : sql.find(")")]
314+
col_label, top_k, frequency_threshold = s.split(", ")
315+
max_categories = int(top_k) + 1
316+
min_frequency = int(frequency_threshold)
317+
318+
return cls(min_frequency, max_categories), col_label
319+
320+
def fit(
321+
self,
322+
X: Union[bpd.DataFrame, bpd.Series],
323+
y=None, # ignored
324+
) -> LabelEncoder:
325+
(X,) = utils.convert_to_dataframe(X)
326+
327+
compiled_transforms = self._compile_to_sql(X.columns.tolist())
328+
transform_sqls = [transform_sql for transform_sql, _ in compiled_transforms]
329+
330+
self._bqml_model = self._bqml_model_factory.create_model(
331+
X,
332+
options={"model_type": "transform_only"},
333+
transforms=transform_sqls,
334+
)
335+
336+
# The schema of TRANSFORM output is not available in the model API, so save it during fitting
337+
self._output_names = [name for _, name in compiled_transforms]
338+
return self
339+
340+
def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
341+
if not self._bqml_model:
342+
raise RuntimeError("Must be fitted before transform")
343+
344+
(X,) = utils.convert_to_dataframe(X)
345+
346+
df = self._bqml_model.transform(X)
347+
return typing.cast(
348+
bpd.DataFrame,
349+
df[self._output_names],
350+
)

bigframes/ml/sql.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ def ml_one_hot_encoder(
8888
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-one-hot-encoder for params."""
8989
return f"""ML.ONE_HOT_ENCODER({numeric_expr_sql}, '{drop}', {top_k}, {frequency_threshold}) OVER() AS {name}"""
9090

91+
def ml_label_encoder(
92+
self,
93+
numeric_expr_sql: str,
94+
top_k: int,
95+
frequency_threshold: int,
96+
name: str,
97+
) -> str:
98+
"""Encode ML.LABEL_ENCODER for BQML.
99+
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-label-encoder for params."""
100+
return f"""ML.LABEL_ENCODER({numeric_expr_sql}, {top_k}, {frequency_threshold}) OVER() AS {name}"""
101+
91102

92103
class ModelCreationSqlGenerator(BaseSqlGenerator):
93104
"""Sql generator for creating a model entity. Model id is the standalone id without project id and dataset id."""

tests/system/large/ml/test_pipeline.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,11 @@ def test_pipeline_columntransformer_fit_predict(session, penguins_df_default_ind
570570
preprocessing.StandardScaler(),
571571
["culmen_length_mm", "flipper_length_mm"],
572572
),
573+
(
574+
"label",
575+
preprocessing.LabelEncoder(),
576+
"species",
577+
),
573578
]
574579
),
575580
),
@@ -632,6 +637,11 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
632637
preprocessing.StandardScaler(),
633638
["culmen_length_mm", "flipper_length_mm"],
634639
),
640+
(
641+
"label",
642+
preprocessing.LabelEncoder(),
643+
"species",
644+
),
635645
]
636646
),
637647
),
@@ -650,7 +660,7 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
650660

651661
assert isinstance(pl_loaded._transform, compose.ColumnTransformer)
652662
transformers = pl_loaded._transform.transformers_
653-
assert len(transformers) == 3
663+
assert len(transformers) == 4
654664

655665
assert transformers[0][0] == "ont_hot_encoder"
656666
assert isinstance(transformers[0][1], preprocessing.OneHotEncoder)
@@ -660,13 +670,20 @@ def test_pipeline_columntransformer_to_gbq(penguins_df_default_index, dataset_id
660670
assert one_hot_encoder.max_categories == 100
661671
assert transformers[0][2] == "species"
662672

663-
assert transformers[1][0] == "standard_scaler"
664-
assert isinstance(transformers[1][1], preprocessing.StandardScaler)
665-
assert transformers[1][2] == "culmen_length_mm"
673+
assert transformers[1][0] == "label_encoder"
674+
assert isinstance(transformers[1][1], preprocessing.LabelEncoder)
675+
one_hot_encoder = transformers[1][1]
676+
assert one_hot_encoder.min_frequency == 0
677+
assert one_hot_encoder.max_categories == 1000001
678+
assert transformers[1][2] == "species"
666679

667680
assert transformers[2][0] == "standard_scaler"
668681
assert isinstance(transformers[2][1], preprocessing.StandardScaler)
669-
assert transformers[2][2] == "flipper_length_mm"
682+
assert transformers[2][2] == "culmen_length_mm"
683+
684+
assert transformers[3][0] == "standard_scaler"
685+
assert isinstance(transformers[2][1], preprocessing.StandardScaler)
686+
assert transformers[3][2] == "flipper_length_mm"
670687

671688
assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
672689
assert pl_loaded._estimator.fit_intercept is False
@@ -735,3 +752,37 @@ def test_pipeline_one_hot_encoder_to_gbq(penguins_df_default_index, dataset_id):
735752

736753
assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
737754
assert pl_loaded._estimator.fit_intercept is False
755+
756+
757+
def test_pipeline_label_encoder_to_gbq(penguins_df_default_index, dataset_id):
758+
pl = pipeline.Pipeline(
759+
[
760+
(
761+
"transform",
762+
preprocessing.LabelEncoder(min_frequency=5, max_categories=100),
763+
),
764+
("estimator", linear_model.LinearRegression(fit_intercept=False)),
765+
]
766+
)
767+
768+
df = penguins_df_default_index.dropna()
769+
X_train = df[
770+
[
771+
"sex",
772+
"species",
773+
]
774+
]
775+
y_train = df[["body_mass_g"]]
776+
pl.fit(X_train, y_train)
777+
778+
pl_loaded = pl.to_gbq(
779+
f"{dataset_id}.test_penguins_pipeline_label_encoder", replace=True
780+
)
781+
assert isinstance(pl_loaded._transform, preprocessing.LabelEncoder)
782+
783+
label_encoder = pl_loaded._transform
784+
assert label_encoder.min_frequency == 5
785+
assert label_encoder.max_categories == 100
786+
787+
assert isinstance(pl_loaded._estimator, linear_model.LinearRegression)
788+
assert pl_loaded._estimator.fit_intercept is False

0 commit comments

Comments
 (0)