Skip to content

Commit cd73658

Browse files
committed
make tests pass
1 parent d9c036a commit cd73658

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

graphtools/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def _get_param_names(cls):
6565

6666
return parameters
6767

68+
def set_params(self, **kwargs):
69+
return self
70+
6871

6972
class Data(Base):
7073
"""Parent class that handles the import and dimensionality reduction of data
@@ -202,6 +205,7 @@ def set_params(self, **params):
202205
raise ValueError("Cannot update n_pca. Please create a new graph")
203206
if 'random_state' in params:
204207
self.random_state = params['random_state']
208+
super().set_params(**params)
205209
return self
206210

207211
def transform(self, Y):
@@ -441,6 +445,7 @@ def set_params(self, **params):
441445
params['kernel_symm'] != self.kernel_symm:
442446
raise ValueError(
443447
"Cannot update kernel_symm. Please create a new graph")
448+
super().set_params(**params)
444449
return self
445450

446451
@property

graphtools/graphs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -681,14 +681,14 @@ def set_params(self, **params):
681681
raise ValueError("Cannot update precomputed. "
682682
"Please create a new graph")
683683
if 'distance' in params and params['distance'] != self.distance and \
684-
self.precomputed is not None:
684+
self.precomputed is None:
685685
raise ValueError("Cannot update distance. "
686686
"Please create a new graph")
687687
if 'knn' in params and params['knn'] != self.knn and \
688-
self.precomputed is not None:
688+
self.precomputed is None:
689689
raise ValueError("Cannot update knn. Please create a new graph")
690690
if 'decay' in params and params['decay'] != self.decay and \
691-
self.precomputed is not None:
691+
self.precomputed is None:
692692
raise ValueError("Cannot update decay. Please create a new graph")
693693
# update superclass parameters
694694
super().set_params(**params)

test/test_landmark.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,18 @@ def test_verbose():
132132
def test_set_params():
133133
G = build_graph(data, n_landmark=500, decay=None)
134134
G.landmark_op
135-
assert G.get_params == {'n_pca': 20,
136-
'random_state': 42,
137-
'kernel_symm': '+',
138-
'gamma': None,
139-
'n_landmark': 500,
140-
'knn': 3,
141-
'decay': None,
142-
'distance':
143-
'euclidean',
144-
'thresh': 0,
145-
'n_jobs': -1,
146-
'verbose': 0}
135+
assert G.get_params() == {'n_pca': 20,
136+
'random_state': 42,
137+
'kernel_symm': '+',
138+
'gamma': None,
139+
'n_landmark': 500,
140+
'knn': 3,
141+
'decay': None,
142+
'distance':
143+
'euclidean',
144+
'thresh': 0,
145+
'n_jobs': -1,
146+
'verbose': 0}
147147
G.set_params(n_landmark=300)
148148
assert G.landmark_op.shape == (300, 300)
149149
G.set_params(n_landmark=G.n_landmark, n_svd=G.n_svd)

0 commit comments

Comments
 (0)