|
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import scipy.sparse as sp |
5 | | -from sklearn.feature_extraction import DictVectorizer |
| 5 | +from pyspark import AccumulatorParam |
6 | 6 | from sklearn.externals import six |
| 7 | +from sklearn.feature_extraction import DictVectorizer |
7 | 8 |
|
| 9 | +from ..base import SparkBroadcasterMixin |
8 | 10 | from ..rdd import DictRDD |
9 | | -from pyspark import AccumulatorParam |
10 | 11 | from ..utils.validation import check_rdd |
11 | | -from ..base import SparkBroadcasterMixin |
12 | 12 |
|
13 | 13 |
|
14 | 14 | class SparkDictVectorizer(DictVectorizer, SparkBroadcasterMixin): |
@@ -87,7 +87,6 @@ def fit(self, Z): |
87 | 87 | self |
88 | 88 | """ |
89 | 89 | X = Z[:, 'X'] if isinstance(Z, DictRDD) else Z |
90 | | - check_rdd(X, (np.ndarray,)) |
91 | 90 |
|
92 | 91 | """Create vocabulary |
93 | 92 | """ |
@@ -142,9 +141,6 @@ def transform(self, Z): |
142 | 141 | Z : transformed, containing {array, sparse matrix} |
143 | 142 | Feature vectors; always 2-d. |
144 | 143 | """ |
145 | | - X = Z[:, 'X'] if isinstance(Z, DictRDD) else Z |
146 | | - check_rdd(X, (np.ndarray, sp.spmatrix)) |
147 | | - |
148 | 144 | mapper = self.broadcast(super(SparkDictVectorizer, self).transform, |
149 | 145 | Z.context) |
150 | 146 | dtype = sp.spmatrix if self.sparse else np.ndarray |
|
0 commit comments