Skip to content

Commit 103e998

Browse files
fix: Escape ids more consistently in ml module (#1074)
1 parent 8d74269 commit 103e998

File tree

11 files changed

+173
-142
lines changed

11 files changed

+173
-142
lines changed

bigframes/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def label_to_identifier(label: typing.Hashable, strict: bool = False) -> str:
116116
"""
117117
# Column values will be loaded as null if the column name has spaces.
118118
# https://github.com/googleapis/python-bigquery/issues/1566
119-
identifier = str(label).replace(" ", "_")
120-
119+
identifier = str(label)
121120
if strict:
121+
identifier = str(label).replace(" ", "_")
122122
identifier = re.sub(r"[^a-zA-Z0-9_]", "", identifier)
123123
if not identifier:
124124
identifier = "id"

bigframes/ml/compose.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud import bigquery
2929

3030
from bigframes.core import log_adapter
31+
import bigframes.core.compile.googlesql as sql_utils
3132
from bigframes.ml import base, core, globals, impute, preprocessing, utils
3233
import bigframes.pandas as bpd
3334

@@ -98,25 +99,22 @@ class SQLScalarColumnTransformer:
9899
def __init__(self, sql: str, target_column: str = "transformed_{0}"):
99100
super().__init__()
100101
self._sql = sql
102+
# TODO: More robust unescaping
101103
self._target_column = target_column.replace("`", "")
102104

103105
PLAIN_COLNAME_RX = re.compile("^[a-z][a-z0-9_]*$", re.IGNORECASE)
104106

105-
def escape(self, colname: str):
106-
colname = colname.replace("`", "")
107-
if self.PLAIN_COLNAME_RX.match(colname):
108-
return colname
109-
return f"`{colname}`"
110-
111107
def _compile_to_sql(
112108
self, X: bpd.DataFrame, columns: Optional[Iterable[str]] = None
113109
) -> List[str]:
114110
if columns is None:
115111
columns = X.columns
116112
result = []
117113
for column in columns:
118-
current_sql = self._sql.format(self.escape(column))
119-
current_target_column = self.escape(self._target_column.format(column))
114+
current_sql = self._sql.format(sql_utils.identifier(column))
115+
current_target_column = sql_utils.identifier(
116+
self._target_column.format(column)
117+
)
120118
result.append(f"{current_sql} AS {current_target_column}")
121119
return result
122120

@@ -239,6 +237,7 @@ def camel_to_snake(name):
239237
transformers_set.add(
240238
(
241239
camel_to_snake(transformer_cls.__name__),
240+
# TODO: This is very fragile, use real SQL parser
242241
*transformer_cls._parse_from_sql(transform_sql), # type: ignore
243242
)
244243
)
@@ -253,7 +252,7 @@ def camel_to_snake(name):
253252

254253
target_column = transform_col_dict["name"]
255254
sql_transformer = SQLScalarColumnTransformer(
256-
transform_sql, target_column=target_column
255+
transform_sql.strip(), target_column=target_column
257256
)
258257
input_column_name = f"?{target_column}"
259258
transformers_set.add(

bigframes/ml/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ class BqmlModel(BaseBqml):
4747
def __init__(self, session: bigframes.Session, model: bigquery.Model):
4848
self._session = session
4949
self._model = model
50+
model_ref = self._model.reference
51+
assert model_ref is not None
5052
self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
51-
self.model_name
53+
model_ref
5254
)
5355

5456
def _apply_ml_tvf(

bigframes/ml/impute.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[SimpleImputer, str]:
8080
tuple(SimpleImputer, column_label)"""
8181
s = sql[sql.find("(") + 1 : sql.find(")")]
8282
col_label, strategy = s.split(", ")
83-
return cls(strategy[1:-1]), col_label # type: ignore[arg-type]
83+
return cls(strategy[1:-1]), _unescape_id(col_label) # type: ignore[arg-type]
8484

8585
def fit(
8686
self,
@@ -110,3 +110,11 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
110110
bpd.DataFrame,
111111
df[self._output_names],
112112
)
113+
114+
115+
def _unescape_id(id: str) -> str:
116+
"""Very simple conversion that removes surrounding ` characters from ids.
117+
118+
A proper sql parser should be used instead.
119+
"""
120+
return id.removeprefix("`").removesuffix("`")

bigframes/ml/preprocessing.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[StandardScaler, str]:
7676
Returns:
7777
tuple(StandardScaler, column_label)"""
7878
col_label = sql[sql.find("(") + 1 : sql.find(")")]
79-
return cls(), col_label
79+
return cls(), _unescape_id(col_label)
8080

8181
def fit(
8282
self,
@@ -152,8 +152,9 @@ def _parse_from_sql(cls, sql: str) -> tuple[MaxAbsScaler, str]:
152152
153153
Returns:
154154
tuple(MaxAbsScaler, column_label)"""
155+
# TODO: Use real sql parser
155156
col_label = sql[sql.find("(") + 1 : sql.find(")")]
156-
return cls(), col_label
157+
return cls(), _unescape_id(col_label)
157158

158159
def fit(
159160
self,
@@ -229,8 +230,9 @@ def _parse_from_sql(cls, sql: str) -> tuple[MinMaxScaler, str]:
229230
230231
Returns:
231232
tuple(MinMaxScaler, column_label)"""
233+
# TODO: Use real sql parser
232234
col_label = sql[sql.find("(") + 1 : sql.find(")")]
233-
return cls(), col_label
235+
return cls(), _unescape_id(col_label)
234236

235237
def fit(
236238
self,
@@ -349,11 +351,11 @@ def _parse_from_sql(cls, sql: str) -> tuple[KBinsDiscretizer, str]:
349351

350352
if sql.startswith("ML.QUANTILE_BUCKETIZE"):
351353
num_bins = s.split(",")[1]
352-
return cls(int(num_bins), "quantile"), col_label
354+
return cls(int(num_bins), "quantile"), _unescape_id(col_label)
353355
else:
354356
array_split_points = s[s.find("[") + 1 : s.find("]")]
355357
n_bins = array_split_points.count(",") + 2
356-
return cls(n_bins, "uniform"), col_label
358+
return cls(n_bins, "uniform"), _unescape_id(col_label)
357359

358360
def fit(
359361
self,
@@ -469,7 +471,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[OneHotEncoder, str]:
469471
max_categories = int(top_k) + 1
470472
min_frequency = int(frequency_threshold)
471473

472-
return cls(drop, min_frequency, max_categories), col_label
474+
return cls(drop, min_frequency, max_categories), _unescape_id(col_label)
473475

474476
def fit(
475477
self,
@@ -578,7 +580,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[LabelEncoder, str]:
578580
max_categories = int(top_k) + 1
579581
min_frequency = int(frequency_threshold)
580582

581-
return cls(min_frequency, max_categories), col_label
583+
return cls(min_frequency, max_categories), _unescape_id(col_label)
582584

583585
def fit(
584586
self,
@@ -661,7 +663,7 @@ def _parse_from_sql(cls, sql: str) -> tuple[PolynomialFeatures, tuple[str, ...]]
661663
col_labels = sql[sql.find("STRUCT(") + 7 : sql.find(")")].split(",")
662664
col_labels = [label.strip() for label in col_labels]
663665
degree = int(sql[sql.rfind(",") + 1 : sql.rfind(")")])
664-
return cls(degree), tuple(col_labels)
666+
return cls(degree), tuple(map(_unescape_id, col_labels))
665667

666668
def fit(
667669
self,
@@ -694,6 +696,14 @@ def transform(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
694696
)
695697

696698

699+
def _unescape_id(id: str) -> str:
700+
"""Very simple conversion that removes surrounding ` characters from ids.
701+
702+
A proper sql parser should be used instead.
703+
"""
704+
return id.removeprefix("`").removesuffix("`")
705+
706+
697707
PreprocessingType = Union[
698708
OneHotEncoder,
699709
StandardScaler,

0 commit comments

Comments
 (0)