@@ -1,14 +1,20 @@
 """
 Utilities to build feature vectors from text documents.
 """
+import itertools
+
 import dask
 import dask.array as da
 import dask.bag as db
 import dask.dataframe as dd
+import distributed
 import numpy as np
 import scipy.sparse
 import sklearn.base
 import sklearn.feature_extraction.text
+from dask.delayed import Delayed
+from distributed import get_client, wait
+from sklearn.utils.validation import check_is_fitted


 class _BaseHasher(sklearn.base.BaseEstimator):
@@ -108,3 +114,188 @@ class FeatureHasher(_BaseHasher, sklearn.feature_extraction.text.FeatureHasher):
     @property
     def _hasher(self):
         return sklearn.feature_extraction.text.FeatureHasher
+
+
+class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
+    """Convert a collection of text documents to a matrix of token counts
+
+    Notes
+    -----
+    When a vocabulary isn't provided, ``fit_transform`` requires two
+    passes over the dataset: one to learn the vocabulary and a second
+    to transform the data. In that case, consider persisting the data
+    in (distributed) memory before calling ``fit`` or ``transform`` if
+    it fits.
+
+    Additionally, this implementation benefits from having
+    an active ``dask.distributed.Client``, even on a single machine.
+    When a client is present, the learned ``vocabulary`` is persisted
+    in distributed memory, which saves some recomputation and redundant
+    communication.
+
+    See Also
+    --------
+    sklearn.feature_extraction.text.CountVectorizer
+
+    Examples
+    --------
+    The Dask-ML implementation currently requires that ``raw_documents``
+    be a :class:`dask.bag.Bag` of documents (strings).
+
+    >>> from dask_ml.feature_extraction.text import CountVectorizer
+    >>> import dask.bag as db
+    >>> from distributed import Client
+    >>> client = Client()
+    >>> corpus = [
+    ...     'This is the first document.',
+    ...     'This document is the second document.',
+    ...     'And this is the third one.',
+    ...     'Is this the first document?',
+    ... ]
+    >>> corpus = db.from_sequence(corpus, npartitions=2)
+    >>> vectorizer = CountVectorizer()
+    >>> X = vectorizer.fit_transform(corpus)
+    >>> X
+    dask.array<concatenate, shape=(nan, 9), dtype=int64, chunksize=(nan, 9), ...
+               chunktype=scipy.csr_matrix>
+    >>> X.compute().toarray()
+    array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
+           [0, 2, 0, 1, 0, 1, 1, 0, 1],
+           [1, 0, 0, 1, 1, 0, 1, 1, 1],
+           [0, 1, 1, 1, 0, 0, 1, 0, 1]])
+    >>> vectorizer.get_feature_names()
+    ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
+    """
+
+    def fit_transform(self, raw_documents, y=None):
+        params = self.get_params()
+        vocabulary = params.pop("vocabulary")
+
+        vocabulary_for_transform = vocabulary
+
+        if self.vocabulary is not None:
+            # Case 1: Just map transform.
+            fixed_vocabulary = True
+            n_features = vocabulary_length(vocabulary)
+            vocabulary_ = vocabulary
+        else:
+            fixed_vocabulary = False
+            # Case 2: learn vocabulary from the data.
+            vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
+            vocabulary = vocabulary_for_transform = _merge_vocabulary(
+                *vocabularies.to_delayed()
+            )
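+            # Persist the merged vocabulary so the ``compute()`` below and
+            # the later transform tasks reuse one computed result instead of
+            # re-running the merge.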
+            vocabulary_for_transform = vocabulary_for_transform.persist()
+            vocabulary_ = vocabulary.compute()
+            n_features = len(vocabulary_)
+
+        result = raw_documents.map_partitions(
+            _count_vectorizer_transform, vocabulary_for_transform, params
+        )
+
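+        # Each partition yields one scipy CSR matrix; stitch them into a dask
+        # array whose per-block row counts are unknown until compute time.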
+        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        result = build_array(result, n_features, meta)
+
+        self.vocabulary_ = vocabulary_
+        self.fixed_vocabulary_ = fixed_vocabulary
+
+        return result
+
+    def transform(self, raw_documents):
+        params = self.get_params()
+        vocabulary = params.pop("vocabulary")
+
+        if vocabulary is None:
+            check_is_fitted(self, "vocabulary_")
+            vocabulary_for_transform = self.vocabulary_
+        else:
+            if isinstance(vocabulary, dict):
+                # Scatter the dict for the user: broadcast one copy to every
+                # worker instead of serializing it into each task. Fall back
+                # to a delayed value when no distributed client is running.
+                try:
+                    client = get_client()
+                except ValueError:
+                    vocabulary_for_transform = dask.delayed(vocabulary)
+                else:
+                    (vocabulary_for_transform,) = client.scatter(
+                        (vocabulary,), broadcast=True
+                    )
+            else:
+                vocabulary_for_transform = vocabulary
+
+        n_features = vocabulary_length(vocabulary_for_transform)
+        transformed = raw_documents.map_partitions(
+            _count_vectorizer_transform, vocabulary_for_transform, params
+        )
+        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        return build_array(transformed, n_features, meta)
+
+
+def build_array(bag, n_features, meta):
+    name = "from-bag-" + bag.name
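+    # Alias each bag partition (a single CSR matrix) to one (row, column)
+    # block of the output array: a task whose value is another task's key is
+    # treated by dask as a reference to that task's result.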
+    layer = {(name, i, 0): (k, i) for (k, i) in bag.__dask_keys__()}
+    dsk = dask.highlevelgraph.HighLevelGraph.from_collections(
+        name, layer, dependencies=[bag]
+    )
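+    # Row counts per block are not known ahead of time, hence the nan chunks.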
+    chunks = ((np.nan,) * bag.npartitions, (n_features,))
+    return da.Array(dsk, name, chunks, meta=meta)
+
+
+def vocabulary_length(vocabulary):
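+    # Get the vocabulary's length where it lives when possible: for a Future,
+    # ``len`` runs on the cluster and only the integer travels back.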
+    if isinstance(vocabulary, dict):
+        return len(vocabulary)
+    elif isinstance(vocabulary, Delayed):
+        try:
+            return len(vocabulary)
+        except TypeError:
+            return len(vocabulary.compute())
+    elif isinstance(vocabulary, distributed.Future):
+        client = get_client()
+        future = client.submit(len, vocabulary)
+        wait(future)
+        result = future.result()
+        return result
+    else:
+        raise ValueError(f"Unknown vocabulary type {type(vocabulary)}.")
+
+
+def _count_vectorizer_transform(partition, vocabulary, params):
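+    # Runs once per bag partition on the workers; dask resolves a delayed or
+    # scattered ``vocabulary`` into a plain dict before calling this.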
+    model = sklearn.feature_extraction.text.CountVectorizer(
+        vocabulary=vocabulary, **params
+    )
+    return model.transform(partition)
+
+
+def _build_vocabulary(partition, params):
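+    # Learn a partial vocabulary from one partition. Only the token set is
+    # returned; final column indices are assigned in _merge_vocabulary.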
+    model = sklearn.feature_extraction.text.CountVectorizer(**params)
+    model.fit(partition)
+    return set(model.vocabulary_)
+
+
+@dask.delayed
+def _merge_vocabulary(*vocabularies):
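+    # Sort the union of the per-partition token sets so that every run maps
+    # each token to the same, deterministic column index.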
+    vocabulary = {
+        key: i
+        for i, key in enumerate(
+            sorted(set(itertools.chain.from_iterable(vocabularies)))
+        )
+    }
+    return vocabulary