Skip to content

Commit 1ad79fe

Browse files
author
Boyan Hristov
committed
#104 - added on_transformed support to BaseCommittee
1 parent 8e0cb25 commit 1ad79fe

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

modAL/models/base.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import scipy.sparse as sp
1818

19-
from modAL.utils.data import data_vstack, modALinput, retrieve_rows
19+
from modAL.utils.data import data_vstack, data_hstack, modALinput, retrieve_rows
2020

2121
if sys.version_info >= (3, 4):
2222
ABC = abc.ABC
@@ -143,13 +143,7 @@ def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.cs
143143

144144
################################
145145
# concatenate all transformations and return
146-
# TODO: maybe use a newly implemented data_hstack() instead
147-
148-
# use sparse representation if any of the pipelines do
149-
if any([isinstance(Xti, sp.csr_matrix) for Xti in Xt]):
150-
return sp.hstack([sp.csc_matrix(Xti) for Xti in Xt])
151-
152-
return np.hstack(Xt)
146+
return data_hstack(Xt)
153147

154148
def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner':
155149
"""
@@ -297,12 +291,15 @@ class BaseCommittee(ABC, BaseEstimator):
297291
Args:
298292
learner_list: List of ActiveLearner objects to form committee.
299293
query_strategy: Function to query labels.
294+
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
295+
when applying the query strategy.
300296
"""
301-
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable) -> None:
297+
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on_transformed: bool = False) -> None:
302298
assert type(learner_list) == list, 'learners must be supplied in a list'
303299

304300
self.learner_list = learner_list
305301
self.query_strategy = query_strategy
302+
self.on_transformed = on_transformed
306303

307304
def __iter__(self) -> Iterator[BaseLearner]:
308305
for learner in self.learner_list:
@@ -369,6 +366,17 @@ def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
369366

370367
return self
371368

369+
def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
370+
"""
371+
Transforms the data as supplied to each learner's estimator and concatenates transformations.
372+
Args:
373+
X: dataset to be transformed
374+
375+
Returns:
376+
Transformed data set
377+
"""
378+
return data_hstack([learner.transform_without_estimating(X) for learner in self.learner_list])
379+
372380
def query(self, X_pool, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
373381
"""
374382
Finds the n_instances most informative point in the data provided by calling the query_strategy function.

modAL/utils/data.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union, Container, List
1+
from typing import Union, List, Sequence
22
from itertools import chain
33

44
import numpy as np
@@ -9,9 +9,9 @@
99
modALinput = Union[list, np.ndarray, sp.csr_matrix, pd.DataFrame]
1010

1111

12-
def data_vstack(blocks: Container) -> modALinput:
12+
def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
1313
"""
14-
Stack vertically both sparse and dense arrays.
14+
Stack vertically sparse/dense arrays and pandas data frames.
1515
1616
Args:
1717
blocks: Sequence of modALinput objects.
@@ -34,6 +34,26 @@ def data_vstack(blocks: Container) -> modALinput:
3434
raise TypeError('%s datatype is not supported' % type(blocks[0]))
3535

3636

37+
def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
38+
"""
39+
Stack horizontally both sparse and dense arrays
40+
41+
Args:
42+
blocks: Sequence of modALinput objects.
43+
44+
Returns:
45+
New sequence of horizontally stacked elements.
46+
"""
47+
# use sparse representation if any of the blocks do
48+
if any([sp.issparse(b) for b in blocks]):
49+
return sp.hstack(blocks)
50+
51+
try:
52+
return np.hstack(blocks)
53+
except:
54+
raise TypeError('%s datatype is not supported' % type(blocks[0]))
55+
56+
3757
def retrieve_rows(X: modALinput,
3858
I: Union[int, List[int], np.ndarray]) -> Union[sp.csc_matrix, np.ndarray, pd.DataFrame]:
3959
"""

0 commit comments

Comments
 (0)