Skip to content

Commit 0a2d24a

Browse files
author
Max Keller
committed
Add DeepActiveLearner tests & fix ValueError
1 parent fd4506d commit 0a2d24a

File tree

3 files changed

+395
-182
lines changed

3 files changed

+395
-182
lines changed

modAL/models/base.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
177177
query_result, query_metrics = self.query_strategy(
178178
self, X_pool, *query_args, **query_kwargs)
179179

180-
except TypeError:
180+
except ValueError:
181181
query_metrics = None
182182
query_result = self.query_strategy(
183183
self, X_pool, *query_args, **query_kwargs)
@@ -246,22 +246,10 @@ def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **f
246246
for learner in self.learner_list:
247247
learner._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
248248

249-
@abc.abstractmethod
250-
def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> Any:
251-
pass
252-
253249
@abc.abstractmethod
254250
def predict(self, X: modALinput) -> Any:
255251
pass
256252

257-
@abc.abstractmethod
258-
def predict_proba(self, X: modALinput, **predict_proba_kwargs) -> Any:
259-
pass
260-
261-
@abc.abstractmethod
262-
def score(self, X: modALinput, y: modALinput, sample_weight: List[float] = None) -> Any:
263-
pass
264-
265253
@abc.abstractmethod
266254
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> Any:
267255
pass
@@ -300,7 +288,7 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
300288
query_result, query_metrics = self.query_strategy(
301289
self, X_pool, *query_args, **query_kwargs)
302290

303-
except TypeError:
291+
except ValueError:
304292
query_metrics = None
305293
query_result = self.query_strategy(
306294
self, X_pool, *query_args, **query_kwargs)
@@ -340,6 +328,3 @@ def _set_classes(self):
340328
def vote(self, X: modALinput) -> Any: # TODO: clarify typing
341329
pass
342330

343-
@abc.abstractmethod
344-
def vote_proba(self, X: modALinput, **predict_proba_kwargs) -> Any:
345-
pass

modAL/utils/data.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
2727
return np.concatenate(blocks)
2828
elif isinstance(blocks[0], list):
2929
return np.concatenate(blocks).tolist()
30-
elif torch.is_tensor(blocks[0]):
30+
elif torch.is_tensor(blocks[0]):
3131
return torch.cat(blocks)
3232

3333
raise TypeError('%s datatype is not supported' % type(blocks[0]))
@@ -51,23 +51,22 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
5151
return np.hstack(blocks)
5252
elif isinstance(blocks[0], list):
5353
return np.hstack(blocks).tolist()
54-
elif torch.is_tensor(blocks[0]):
54+
elif torch.is_tensor(blocks[0]):
5555
return torch.cat(blocks, dim=1)
5656

5757
TypeError('%s datatype is not supported' % type(blocks[0]))
5858

5959

60-
def add_row(X:modALinput, row: modALinput):
60+
def add_row(X: modALinput, row: modALinput):
6161
"""
6262
Returns X' =
6363
6464
[X
6565
66-
row]
67-
"""
66+
row] """
6867
if isinstance(X, np.ndarray):
6968
return np.vstack((X, row))
70-
elif torch.is_tensor(X):
69+
elif torch.is_tensor(X):
7170
return torch.cat((X, row))
7271
elif isinstance(X, list):
7372
return np.vstack((X, row)).tolist()
@@ -102,7 +101,7 @@ def retrieve_rows(X: modALinput,
102101
return X.iloc[I]
103102
elif isinstance(X, list):
104103
return np.array(X)[I].tolist()
105-
elif isinstance(X, dict):
104+
elif isinstance(X, dict):
106105
X_return = {}
107106
for key, value in X.items():
108107
X_return[key] = retrieve_rows(value, I)
@@ -118,7 +117,6 @@ def retrieve_rows(X: modALinput,
118117
def drop_rows(X: modALinput,
119118
I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
120119
"""
121-
TODO: Add pytorch support
122120
Returns X without the row(s) at index/indices I
123121
"""
124122
if sp.issparse(X):
@@ -131,6 +129,9 @@ def drop_rows(X: modALinput,
131129
return np.delete(X, I, axis=0)
132130
elif isinstance(X, list):
133131
return np.delete(X, I, axis=0).tolist()
132+
elif torch.is_tensor(X):
133+
return X[[True if row not in I else False
134+
for row in range(X.size(0))]]
134135

135136
raise TypeError('%s datatype is not supported' % type(X))
136137

@@ -165,7 +166,7 @@ def data_shape(X: modALinput):
165166
return X.shape
166167
elif isinstance(X, list):
167168
return np.array(X).shape
168-
elif torch.is_tensor(X):
169+
elif torch.is_tensor(X):
169170
return tuple(X.size())
170171

171172
raise TypeError('%s datatype is not supported' % type(X))

0 commit comments

Comments
 (0)