Skip to content

Commit 34ef718

Browse files
committed
Create properties for getting sensors
1 parent 313e8e0 commit 34ef718

File tree

4 files changed

+80
-1
lines changed

4 files changed

+80
-1
lines changed

pysensors/classification/_sspoc.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,24 @@ def update_n_basis_modes(self, n_basis_modes, xy, **fit_kws):
401401

402402
@property
403403
def selected_sensors(self):
404-
"""Get the selected sensors."""
404+
"""
405+
Get the indices of the selected sensors.
406+
407+
Returns
408+
-------
409+
sensors: numpy array, shape (n_sensors,)
410+
Indices of the selected sensors.
411+
"""
405412
check_is_fitted(self, "sparse_sensors_")
406413
return self.sparse_sensors_
414+
415+
def get_selected_sensors(self):
416+
"""
417+
Convenience function for getting indices of the selected sensors.
418+
419+
Returns
420+
-------
421+
sensors: numpy array, shape (n_sensors,)
422+
Indices of the selected sensors.
423+
"""
424+
return self.selected_sensors

pysensors/pysensors.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,38 @@ def get_selected_sensors(self):
212212
check_is_fitted(self, "ranked_sensors_")
213213
return self.ranked_sensors_[: self.n_sensors]
214214

215+
@property
216+
def selected_sensors(self):
217+
"""
218+
Get the indices of the sensors chosen by the model.
219+
220+
Returns
221+
-------
222+
sensors: numpy array, shape (n_sensors,)
223+
Indices of the sensors chosen by the model
224+
(i.e. the sensor locations) ranked in descending order
225+
of importance.
226+
"""
227+
return self.get_selected_sensors()
228+
215229
def get_all_sensors(self):
216230
"""
217231
Get a ranked list consisting of all the sensors.
218232
The sensors are given in descending order of importance.
219233
234+
Returns
235+
-------
236+
sensors: numpy array, shape (n_features,)
237+
Indices of sensors in descending order of importance.
238+
"""
239+
return self.all_sensors
240+
241+
@property
242+
def all_sensors(self):
243+
"""
244+
Get a ranked list consisting of all the sensors.
245+
The sensors are given in descending order of importance.
246+
220247
Returns
221248
-------
222249
sensors: numpy array, shape (n_features,)
@@ -249,6 +276,20 @@ def set_number_of_sensors(self, n_sensors):
249276
else:
250277
self.n_sensors = n_sensors
251278

279+
def set_n_sensors(self, n_sensors):
280+
"""
281+
A convenience function accomplishing the same thing as
282+
:meth:`set_number_of_sensors`.
283+
Set ``n_sensors``, the number of sensors to be used for prediction.
284+
285+
Parameters
286+
----------
287+
n_sensors: int
288+
The number of sensors. Must be a positive integer.
289+
Cannot exceed the number of available sensors (n_features).
290+
"""
291+
self.set_number_of_sensors(n_sensors)
292+
252293
def update_n_basis_modes(self, n_basis_modes, x=None):
253294
"""
254295
Re-fit the SensorSelector object using a different value of
@@ -418,6 +459,10 @@ def score(x, y):
418459
return error
419460

420461
def _validate_n_sensors(self):
462+
"""
463+
Check that number of sensors does not exceed the maximimum number
464+
allowed by the chosen basis.
465+
"""
421466
check_is_fitted(self, "basis_matrix_")
422467

423468
# Maximum number of sensors (= dimension of basis vectors)

test/classification/test_sspoc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,11 @@ def test_update_n_basis_modes_unfit_basis(basis, data_binary_classification):
263263
model.update_n_basis_modes(n_basis_modes, (x, y), quiet=True)
264264

265265
assert model.basis_matrix_inverse_.shape[0] == n_basis_modes
266+
267+
268+
def test_sspoc_selector_equivalence(data_multiclass_classification):
269+
x, y, _ = data_multiclass_classification
270+
271+
model = SSPOC().fit(x, y)
272+
273+
np.testing.assert_array_equal(model.get_selected_sensors(), model.selected_sensors)

test/test_pysensors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,11 @@ def test_ccqr_integration(data_random):
255255
model = SensorSelector(optimizer=optimizer).fit(data)
256256

257257
check_is_fitted(model)
258+
259+
260+
def test_sensor_selector_properties(data_random):
261+
data = data_random
262+
model = SensorSelector().fit(data)
263+
264+
assert all(model.get_all_sensors() == model.all_sensors)
265+
assert all(model.get_selected_sensors() == model.selected_sensors)

0 commit comments

Comments
 (0)