File tree Expand file tree Collapse file tree 5 files changed +69
-2
lines changed
third_party/bigframes_vendored/sklearn Expand file tree Collapse file tree 5 files changed +69
-2
lines changed Original file line number Diff line number Diff line change @@ -248,6 +248,13 @@ def fit(
248248 ) -> _T :
249249 return self ._fit (X , y )
250250
251+ def fit_predict (
252+ self : _T ,
253+ X : utils .ArrayType ,
254+ y : Optional [utils .ArrayType ] = None ,
255+ ) -> _T :
256+ return self .fit (X ).predict (X )
257+
251258
252259class RetriableRemotePredictor (BaseEstimator ):
253260 def _predict_and_retry (
Original file line number Diff line number Diff line change 17361736 "provenance" : []
17371737 },
17381738 "kernelspec" : {
1739- "display_name" : " Python 3 (ipykernel )" ,
1739+ "display_name" : " venv (3.10.14 )" ,
17401740 "language" : " python" ,
17411741 "name" : " python3"
17421742 },
17501750 "name" : " python" ,
17511751 "nbconvert_exporter" : " python" ,
17521752 "pygments_lexer" : " ipython3" ,
1753- "version" : " 3.10.9 "
1753+ "version" : " 3.10.14 "
17541754 }
17551755 },
17561756 "nbformat" : 4 ,
Original file line number Diff line number Diff line change @@ -115,6 +115,26 @@ def predict(
115115 """
116116 raise NotImplementedError (constants .ABSTRACT_METHOD_ERROR_MESSAGE )
117117
118+ def fit_predict (
119+ self ,
120+ X ,
121+ y = None ,
122+ ):
123+ """Compute cluster centers and predict cluster index for each sample.
124+
125+ Convenience method; equivalent to calling fit(X) followed by predict(X).
126+
127+ Args:
128+ X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
129+ DataFrame of shape (n_samples, n_features). Training data.
130+ y (default None):
131+ Not used, present here for API consistency by convention.
132+
133+ Returns:
134+ bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted labels.
135+ """
136+ raise NotImplementedError (constants .ABSTRACT_METHOD_ERROR_MESSAGE )
137+
118138 def score (
119139 self ,
120140 X ,
Original file line number Diff line number Diff line change @@ -94,3 +94,23 @@ def predict(self, X):
9494 Returns:
9595 bigframes.dataframe.DataFrame: Predicted DataFrames."""
9696 raise NotImplementedError (constants .ABSTRACT_METHOD_ERROR_MESSAGE )
97+
98+ def fit_predict (
99+ self ,
100+ X ,
101+ y = None ,
102+ ):
103+ """Fit the model with X and generate a predicted rating for every user-item row combination for a matrix factorization model. on X.
104+
105+ Convenience method; equivalent to calling fit(X) followed by predict(X).
106+
107+ Args:
108+ X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
109+ DataFrame of shape (n_samples, n_features). Training data.
110+ y (default None):
111+ Not used, present here for API consistency by convention.
112+
113+ Returns:
114+ bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted labels.
115+ """
116+ raise NotImplementedError (constants .ABSTRACT_METHOD_ERROR_MESSAGE )
Original file line number Diff line number Diff line change @@ -101,6 +101,26 @@ def predict(self, X):
101101 bigframes.dataframe.DataFrame: Predicted DataFrames."""
102102 raise NotImplementedError (constants .ABSTRACT_METHOD_ERROR_MESSAGE )
103103
104+ def fit_predict (
105+ self ,
106+ X ,
107+ y = None ,
108+ ):
109+ """Fit the model with X and apply the dimensionality reduction on X.
110+
111+ Convenience method; equivalent to calling fit(X) followed by predict(X).
112+
113+ Args:
114+ X (bigframes.dataframe.DataFrame or bigframes.series.Series or pandas.core.frame.DataFrame or pandas.core.series.Series):
115+ DataFrame of shape (n_samples, n_features). Training data.
116+ y (default None):
117+ Not used, present here for API consistency by convention.
118+
119+ Returns:
120+ bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted labels.
121+ """
122+ raise NotImplementedError (constants .ABSTRACT_METHOD_ERROR_MESSAGE )
123+
104124 @property
105125 def components_ (self ):
106126 """Principal axes in feature space, representing the directions of maximum variance in the data.
You can’t perform that action at this time.
0 commit comments