|
16 | 16 |
|
17 | 17 | import scipy.sparse as sp |
18 | 18 |
|
19 | | -from modAL.utils.data import data_vstack, modALinput, retrieve_rows |
| 19 | +from modAL.utils.data import data_vstack, data_hstack, modALinput, retrieve_rows |
20 | 20 |
|
21 | 21 | if sys.version_info >= (3, 4): |
22 | 22 | ABC = abc.ABC |
@@ -143,13 +143,7 @@ def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.cs |
143 | 143 |
|
144 | 144 | ################################ |
145 | 145 | # concatenate all transformations and return |
146 | | - # TODO: maybe use a newly implemented data_hstack() instead |
147 | | - |
148 | | - # use sparse representation if any of the pipelines do |
149 | | - if any([isinstance(Xti, sp.csr_matrix) for Xti in Xt]): |
150 | | - return sp.hstack([sp.csc_matrix(Xti) for Xti in Xt]) |
151 | | - |
152 | | - return np.hstack(Xt) |
| 146 | + return data_hstack(Xt) |
153 | 147 |
|
154 | 148 | def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner': |
155 | 149 | """ |
@@ -297,12 +291,15 @@ class BaseCommittee(ABC, BaseEstimator): |
297 | 291 | Args: |
298 | 292 | learner_list: List of ActiveLearner objects to form committee. |
299 | 293 | query_strategy: Function to query labels. |
| 294 | + on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator |
| 295 | + when applying the query strategy. |
300 | 296 | """ |
301 | | - def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable) -> None: |
| 297 | + def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on_transformed: bool = False) -> None: |
302 | 298 | assert type(learner_list) == list, 'learners must be supplied in a list' |
303 | 299 |
|
304 | 300 | self.learner_list = learner_list |
305 | 301 | self.query_strategy = query_strategy |
| 302 | + self.on_transformed = on_transformed |
306 | 303 |
|
307 | 304 | def __iter__(self) -> Iterator[BaseLearner]: |
308 | 305 | for learner in self.learner_list: |
@@ -369,6 +366,17 @@ def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee': |
369 | 366 |
|
370 | 367 | return self |
371 | 368 |
|
| 369 | + def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]: |
| 370 | + """ |
| 371 | + Transforms the data as it would be supplied to each learner's estimator and horizontally concatenates the resulting transformations. |
| 372 | + Args: |
| 373 | + X: dataset to be transformed |
| 374 | +
|
| 375 | + Returns: |
| 376 | + Transformed data set |
| 377 | + """ |
| 378 | + return data_hstack([learner.transform_without_estimating(X) for learner in self.learner_list]) |
| 379 | + |
372 | 380 | def query(self, X_pool, *query_args, **query_kwargs) -> Union[Tuple, modALinput]: |
373 | 381 | """ |
374 | 382 | Finds the n_instances most informative points in the provided data by calling the query_strategy function. |
|
0 commit comments