16
16
from sklearn .metrics import r2_score
17
17
18
18
from ..utils import _validate_type , fill_doc , pinv
19
+ from ._fixes import _check_n_features_3d , validate_data
19
20
from .base import _check_estimator , get_coef
20
21
from .time_delaying_ridge import TimeDelayingRidge
21
22
@@ -125,7 +126,7 @@ def __init__(
125
126
self .tmax = tmax
126
127
self .sfreq = sfreq
127
128
self .feature_names = feature_names
128
- self .estimator = 0.0 if estimator is None else estimator
129
+ self .estimator = estimator
129
130
self .fit_intercept = fit_intercept
130
131
self .scoring = scoring
131
132
self .patterns = patterns
@@ -152,6 +153,19 @@ def __repr__(self): # noqa: D105
152
153
s += f"scored ({ self .scoring } )"
153
154
return f"<ReceptiveField | { s } >"
154
155
156
+ def __sklearn_tags__ (self ):
157
+ """..."""
158
+ from sklearn .utils import RegressorTags
159
+
160
+ tags = super ().__sklearn_tags__ ()
161
+ tags .estimator_type = "regressor"
162
+ tags .regressor_tags = RegressorTags ()
163
+ tags .input_tags .three_d_array = True
164
+ tags .target_tags .one_d_labels = True
165
+ tags .target_tags .multi_output = True
166
+ tags .target_tags .required = True
167
+ return tags
168
+
155
169
def _delay_and_reshape (self , X , y = None ):
156
170
"""Delay and reshape the variables."""
157
171
if not isinstance (self .estimator_ , TimeDelayingRidge ):
@@ -169,6 +183,32 @@ def _delay_and_reshape(self, X, y=None):
169
183
y = y .reshape (- 1 , y .shape [- 1 ], order = "F" )
170
184
return X , y
171
185
186
+ def _check_data (self , X , y = None , reset = False ):
187
+ if reset :
188
+ X , y = validate_data (
189
+ self ,
190
+ X = X ,
191
+ y = y ,
192
+ reset = reset ,
193
+ validate_separately = ( # to take care of 3D y
194
+ dict (allow_nd = True , ensure_2d = False ),
195
+ dict (allow_nd = True , ensure_2d = False ),
196
+ ),
197
+ )
198
+ else :
199
+ X = validate_data (self , X = X , allow_nd = True , ensure_2d = False , reset = reset )
200
+ _check_n_features_3d (self , X , reset )
201
+ return X , y
202
+
203
+ def _validate_params (self , X ):
204
+ if self .scoring not in _SCORERS .keys ():
205
+ raise ValueError (
206
+ f"scoring must be one of { sorted (_SCORERS .keys ())} , got { self .scoring } "
207
+ )
208
+ self .sfreq_ = float (self .sfreq )
209
+ if self .tmin > self .tmax :
210
+ raise ValueError (f"tmin ({ self .tmin } ) must be at most tmax ({ self .tmax } )" )
211
+
172
212
def fit (self , X , y ):
173
213
"""Fit a receptive field model.
174
214
@@ -184,22 +224,18 @@ def fit(self, X, y):
184
224
self : instance
185
225
The instance so you can chain operations.
186
226
"""
187
- if self .scoring not in _SCORERS .keys ():
188
- raise ValueError (
189
- f"scoring must be one of { sorted (_SCORERS .keys ())} , got { self .scoring } "
190
- )
191
- self .sfreq_ = float (self .sfreq )
227
+ X , y = self ._check_data (X , y , reset = True )
228
+ self ._validate_params (X )
192
229
X , y , _ , self ._y_dim = self ._check_dimensions (X , y )
193
230
194
- if self .tmin > self .tmax :
195
- raise ValueError (f"tmin ({ self .tmin } ) must be at most tmax ({ self .tmax } )" )
196
231
# Initialize delays
197
232
self .delays_ = _times_to_delays (self .tmin , self .tmax , self .sfreq_ )
198
233
199
234
# Define the slice that we should use in the middle
200
235
self .valid_samples_ = _delays_to_slice (self .delays_ )
201
236
202
- if isinstance (self .estimator , numbers .Real ):
237
+ if self .estimator is None or isinstance (self .estimator , numbers .Real ):
238
+ alpha = self .estimator if self .estimator is not None else 0.0
203
239
if self .fit_intercept is None :
204
240
self .fit_intercept_ = True
205
241
else :
@@ -208,7 +244,7 @@ def fit(self, X, y):
208
244
self .tmin ,
209
245
self .tmax ,
210
246
self .sfreq_ ,
211
- alpha = self . estimator ,
247
+ alpha = alpha ,
212
248
fit_intercept = self .fit_intercept_ ,
213
249
n_jobs = self .n_jobs ,
214
250
edge_correction = self .edge_correction ,
@@ -259,6 +295,12 @@ def fit(self, X, y):
259
295
260
296
# Inverse-transform model weights
261
297
if self .patterns :
298
+ n_total_samples = n_times * n_epochs
299
+ if n_total_samples < 2 :
300
+ raise ValueError (
301
+ "Cannot compute patterns with only one sample; "
302
+ f"got n_samples = { n_total_samples } ."
303
+ )
262
304
if isinstance (self .estimator_ , TimeDelayingRidge ):
263
305
cov_ = self .estimator_ .cov_ / float (n_times * n_epochs - 1 )
264
306
y = y .reshape (- 1 , y .shape [- 1 ], order = "F" )
@@ -300,7 +342,10 @@ def predict(self, X):
300
342
"""
301
343
if not hasattr (self , "delays_" ):
302
344
raise NotFittedError ("Estimator has not been fit yet." )
345
+
346
+ X , _ = self ._check_data (X )
303
347
X , _ , X_dim = self ._check_dimensions (X , None , predict = True )[:3 ]
348
+
304
349
del _
305
350
# convert to sklearn and back
306
351
pred_shape = X .shape [:- 1 ]
@@ -384,7 +429,10 @@ def _check_dimensions(self, X, y, predict=False):
384
429
)
385
430
else :
386
431
raise ValueError (
387
- f"X must be shape (n_times[, n_epochs], n_features), got { X .shape } "
432
+ "X must be shape (n_times[, n_epochs], n_features), "
433
+ f"got { X .shape } . Reshape your data to 2D or 3D "
434
+ "(e.g., array.reshape(-1, 1) for a single feature, "
435
+ "or array.reshape(1, -1) for a single sample)."
388
436
)
389
437
if y is not None :
390
438
if X .shape [0 ] != y .shape [0 ]:
0 commit comments