Skip to content

Commit 8d4da15

Browse files
authored
fix: exclude index columns from model fitting processes. (#1138)
* fix: exclude index columns from model fitting processes.
* update logic
* fix unit test
* remove empty line
1 parent a61eb4d commit 8d4da15

File tree

4 files changed: 56 additions and 6 deletions

4 files changed: 56 additions and 6 deletions

bigframes/ml/core.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -307,9 +307,11 @@ def create_model(
307307
# Cache dataframes to make sure base table is not a snapshot
308308
# cached dataframe creates a full copy, never uses snapshot
309309
if y_train is None:
310-
input_data = X_train.cache()
310+
input_data = X_train.reset_index(drop=True).cache()
311311
else:
312-
input_data = X_train.join(y_train, how="outer").cache()
312+
input_data = (
313+
X_train.join(y_train, how="outer").reset_index(drop=True).cache()
314+
)
313315
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})
314316

315317
session = X_train._session

tests/system/large/ml/test_cluster.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -154,3 +154,13 @@ def test_cluster_configure_fit_load_params(penguins_df_default_index, dataset_id
154154
assert reloaded_model.distance_type == "COSINE"
155155
assert reloaded_model.max_iter == 30
156156
assert reloaded_model.tol == 0.001
157+
158+
159+
def test_model_centroids_with_custom_index(penguins_df_default_index):
160+
model = cluster.KMeans(n_clusters=3)
161+
penguins = penguins_df_default_index.set_index(["species", "island", "sex"])
162+
model.fit(penguins)
163+
164+
assert (
165+
not model.cluster_centers_["feature"].isin(["species", "island", "sex"]).any()
166+
)

tests/system/large/ml/test_linear_model.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -425,3 +425,30 @@ def test_logistic_regression_customized_params_fit_score(
425425
assert reloaded_model.tol == 0.02
426426
assert reloaded_model.learning_rate_strategy == "CONSTANT"
427427
assert reloaded_model.learning_rate == 0.2
428+
429+
430+
def test_model_centroids_with_custom_index(penguins_df_default_index):
431+
model = bigframes.ml.linear_model.LogisticRegression(
432+
fit_intercept=False,
433+
class_weight="balanced",
434+
l2_reg=0.2,
435+
tol=0.02,
436+
l1_reg=0.2,
437+
max_iterations=30,
438+
optimize_strategy="batch_gradient_descent",
439+
learning_rate_strategy="constant",
440+
learning_rate=0.2,
441+
)
442+
df = penguins_df_default_index.dropna().set_index(["species", "island"])
443+
X_train = df[
444+
[
445+
"culmen_length_mm",
446+
"culmen_depth_mm",
447+
"flipper_length_mm",
448+
]
449+
]
450+
y_train = df[["sex"]]
451+
model.fit(X_train, y_train)
452+
453+
# If this line executes without errors, the model has correctly ignored the custom index columns
454+
model.predict(X_train.reset_index(drop=True))

tests/unit/ml/test_golden_sql.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,17 @@ def mock_X(mock_y, mock_session):
8585
["index_column_id"],
8686
["index_column_label"],
8787
)
88+
89+
mock_X.join(mock_y).reset_index(drop=True).sql = "input_X_y_no_index_sql"
90+
mock_X.join(mock_y).reset_index(drop=True).cache.return_value = mock_X.join(
91+
mock_y
92+
).reset_index(drop=True)
93+
mock_X.join(mock_y).reset_index(drop=True)._to_sql_query.return_value = (
94+
"input_X_y_no_index_sql",
95+
["index_column_id"],
96+
["index_column_label"],
97+
)
98+
8899
mock_X.cache.return_value = mock_X
89100

90101
return mock_X
@@ -107,7 +118,7 @@ def test_linear_regression_default_fit(
107118
model.fit(mock_X, mock_y)
108119

109120
mock_session._start_query_ml_ddl.assert_called_once_with(
110-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
121+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
111122
)
112123

113124

@@ -117,7 +128,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X,
117128
model.fit(mock_X, mock_y)
118129

119130
mock_session._start_query_ml_ddl.assert_called_once_with(
120-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
131+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
121132
)
122133

123134

@@ -150,7 +161,7 @@ def test_logistic_regression_default_fit(
150161
model.fit(mock_X, mock_y)
151162

152163
mock_session._start_query_ml_ddl.assert_called_once_with(
153-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
164+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
154165
)
155166

156167

@@ -172,7 +183,7 @@ def test_logistic_regression_params_fit(
172183
model.fit(mock_X, mock_y)
173184

174185
mock_session._start_query_ml_ddl.assert_called_once_with(
175-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_sql"
186+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
176187
)
177188

178189

0 commit comments

Comments (0)