Skip to content

Commit e950533

Browse files
feat: Allow DataFrame.join for self-join on Null index (#860)
* feat: Allow DataFrame.join for self-join on Null index * fix ml caching to apply post-join, add test * fix ml golden sql test * change unordered test to use linear regression
1 parent 8e04c38 commit e950533

File tree

5 files changed

+68
-9
lines changed

5 files changed

+68
-9
lines changed

bigframes/core/blocks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2307,11 +2307,11 @@ def join(
23072307
f"Only how='outer','left','right','inner' currently supported. {constants.FEEDBACK_LINK}"
23082308
)
23092309
# Handle null index, which only supports row join
2310-
if (self.index.nlevels == other.index.nlevels == 0) and not block_identity_join:
2311-
if not block_identity_join:
2312-
result = try_row_join(self, other, how=how)
2313-
if result is not None:
2314-
return result
2310+
# This is the canonical way of aligning on null index, so always allow (ignore block_identity_join)
2311+
if self.index.nlevels == other.index.nlevels == 0:
2312+
result = try_row_join(self, other, how=how)
2313+
if result is not None:
2314+
return result
23152315
raise bigframes.exceptions.NullIndexError(
23162316
"Cannot implicitly align objects. Set an explicit index using set_index."
23172317
)

bigframes/ml/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def distance(
8383
"""
8484
assert len(x.columns) == 1 and len(y.columns) == 1
8585

86-
input_data = x.cache().join(y.cache(), how="outer")
86+
input_data = x.join(y, how="outer").cache()
8787
x_column_id, y_column_id = x._block.value_columns[0], y._block.value_columns[0]
8888

8989
return self._apply_sql(
@@ -326,7 +326,7 @@ def create_model(
326326
if y_train is None:
327327
input_data = X_train.cache()
328328
else:
329-
input_data = X_train.cache().join(y_train.cache(), how="outer")
329+
input_data = X_train.join(y_train, how="outer").cache()
330330
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})
331331

332332
session = X_train._session
@@ -366,7 +366,7 @@ def create_llm_remote_model(
366366
options = dict(options)
367367
# Cache dataframes to make sure base table is not a snapshot
368368
# cached dataframe creates a full copy, never uses snapshot
369-
input_data = X_train.cache().join(y_train.cache(), how="outer")
369+
input_data = X_train.join(y_train, how="outer").cache()
370370
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})
371371

372372
session = X_train._session
@@ -399,7 +399,7 @@ def create_time_series_model(
399399
options = dict(options)
400400
# Cache dataframes to make sure base table is not a snapshot
401401
# cached dataframe creates a full copy, never uses snapshot
402-
input_data = X_train.cache().join(y_train.cache(), how="outer")
402+
input_data = X_train.join(y_train, how="outer").cache()
403403
options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]})
404404
options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]})
405405

tests/system/large/ml/test_linear_model.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,50 @@ def test_linear_regression_customized_params_fit_score(
111111
assert reloaded_model.learning_rate == 0.2
112112

113113

114+
def test_unordered_mode_regression_configure_fit_score(
115+
unordered_session, penguins_table_id, dataset_id
116+
):
117+
model = bigframes.ml.linear_model.LinearRegression()
118+
119+
df = unordered_session.read_gbq(penguins_table_id).dropna()
120+
X_train = df[
121+
[
122+
"species",
123+
"island",
124+
"culmen_length_mm",
125+
"culmen_depth_mm",
126+
"flipper_length_mm",
127+
"sex",
128+
]
129+
]
130+
y_train = df[["body_mass_g"]]
131+
model.fit(X_train, y_train)
132+
133+
# Check score to ensure the model was fitted
134+
result = model.score(X_train, y_train).to_pandas()
135+
utils.check_pandas_df_schema_and_index(
136+
result, columns=utils.ML_REGRESSION_METRICS, index=1
137+
)
138+
139+
# save, load, check parameters to ensure configuration was kept
140+
reloaded_model = model.to_gbq(f"{dataset_id}.temp_configured_model", replace=True)
141+
assert reloaded_model._bqml_model is not None
142+
assert (
143+
f"{dataset_id}.temp_configured_model" in reloaded_model._bqml_model.model_name
144+
)
145+
assert reloaded_model.optimize_strategy == "NORMAL_EQUATION"
146+
assert reloaded_model.fit_intercept is True
147+
assert reloaded_model.calculate_p_values is False
148+
assert reloaded_model.enable_global_explain is False
149+
assert reloaded_model.l1_reg is None
150+
assert reloaded_model.l2_reg == 0.0
151+
assert reloaded_model.learning_rate is None
152+
assert reloaded_model.learning_rate_strategy == "line_search"
153+
assert reloaded_model.ls_init_learning_rate is None
154+
assert reloaded_model.max_iterations == 20
155+
assert reloaded_model.tol == 0.01
156+
157+
114158
# TODO(garrettwu): add tests for param warm_start. Requires a trained model.
115159

116160

tests/system/small/test_null_index.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def test_null_index_stack(scalars_df_null_index, scalars_pandas_df_default_index
201201
)
202202

203203

204+
def test_null_index_series_self_join(
205+
scalars_df_null_index, scalars_pandas_df_default_index
206+
):
207+
bf_result = scalars_df_null_index[["int64_col"]].join(
208+
scalars_df_null_index[["int64_too"]]
209+
)
210+
pd_result = scalars_pandas_df_default_index[["int64_col"]].join(
211+
scalars_pandas_df_default_index[["int64_too"]]
212+
)
213+
pd.testing.assert_frame_equal(
214+
bf_result.to_pandas(), pd_result.reset_index(drop=True), check_dtype=False
215+
)
216+
217+
204218
def test_null_index_series_self_aligns(
205219
scalars_df_null_index, scalars_pandas_df_default_index
206220
):

tests/unit/ml/test_golden_sql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def mock_X(mock_y, mock_session):
7878
["index_column_label"],
7979
)
8080
mock_X.join(mock_y).sql = "input_X_y_sql"
81+
mock_X.join(mock_y).cache.return_value = mock_X.join(mock_y)
8182
mock_X.join(mock_y)._to_sql_query.return_value = (
8283
"input_X_y_sql",
8384
["index_column_id"],

0 commit comments

Comments
 (0)