Skip to content

Commit 59df7f7

Browse files
authored
feat: add fit_predict method to ml unsupervised models (#2320)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea.
- [ ] Ensure the tests and linter pass.
- [ ] Code coverage does not decrease (if any source code was changed).
- [ ] Appropriate docs were updated (if necessary).

Fixes #<issue_number_goes_here> 🦕
1 parent 83da622 commit 59df7f7

File tree

5 files changed

+69
-2
lines changed

5 files changed

+69
-2
lines changed

bigframes/ml/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ def fit(
248248
) -> _T:
249249
return self._fit(X, y)
250250

251+
def fit_predict(
252+
self: _T,
253+
X: utils.ArrayType,
254+
y: Optional[utils.ArrayType] = None,
255+
) -> _T:
256+
return self.fit(X).predict(X)
257+
251258

252259
class RetriableRemotePredictor(BaseEstimator):
253260
def _predict_and_retry(

notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,7 @@
17361736
"provenance": []
17371737
},
17381738
"kernelspec": {
1739-
"display_name": "Python 3 (ipykernel)",
1739+
"display_name": "venv (3.10.14)",
17401740
"language": "python",
17411741
"name": "python3"
17421742
},
@@ -1750,7 +1750,7 @@
17501750
"name": "python",
17511751
"nbconvert_exporter": "python",
17521752
"pygments_lexer": "ipython3",
1753-
"version": "3.10.9"
1753+
"version": "3.10.14"
17541754
}
17551755
},
17561756
"nbformat": 4,

third_party/bigframes_vendored/sklearn/cluster/_kmeans.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,26 @@ def predict(
115115
"""
116116
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
117117

118+
def fit_predict(
119+
self,
120+
X,
121+
y=None,
122+
):
123+
"""Compute cluster centers and predict cluster index for each sample.
124+
125+
Convenience method; equivalent to calling fit(X) followed by predict(X).
126+
127+
Args:
128+
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
129+
DataFrame of shape (n_samples, n_features). Training data.
130+
y (default None):
131+
Not used, present here for API consistency by convention.
132+
133+
Returns:
134+
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted labels.
135+
"""
136+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
137+
118138
def score(
119139
self,
120140
X,

third_party/bigframes_vendored/sklearn/decomposition/_mf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,23 @@ def predict(self, X):
9494
Returns:
9595
bigframes.dataframe.DataFrame: Predicted DataFrames."""
9696
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
97+
98+
def fit_predict(
99+
self,
100+
X,
101+
y=None,
102+
):
103+
"""Fit the matrix factorization model with X and generate a predicted rating for every user-item row combination in X.
104+
105+
Convenience method; equivalent to calling fit(X) followed by predict(X).
106+
107+
Args:
108+
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
109+
DataFrame of shape (n_samples, n_features). Training data.
110+
y (default None):
111+
Not used, present here for API consistency by convention.
112+
113+
Returns:
114+
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted ratings.
115+
"""
116+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

third_party/bigframes_vendored/sklearn/decomposition/_pca.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ def predict(self, X):
101101
bigframes.dataframe.DataFrame: Predicted DataFrames."""
102102
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
103103

104+
def fit_predict(
105+
self,
106+
X,
107+
y=None,
108+
):
109+
"""Fit the model with X and apply the dimensionality reduction on X.
110+
111+
Convenience method; equivalent to calling fit(X) followed by predict(X).
112+
113+
Args:
114+
X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
115+
DataFrame of shape (n_samples, n_features). Training data.
116+
y (default None):
117+
Not used, present here for API consistency by convention.
118+
119+
Returns:
120+
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns the result of applying the dimensionality reduction to X.
121+
"""
122+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
123+
104124
@property
105125
def components_(self):
106126
"""Principal axes in feature space, representing the directions of maximum variance in the data.

0 commit comments

Comments
 (0)