Commit 236e13a

Count vectorizer (with Actors) (#705)
Parent: ac6af85

4 files changed: 246 additions, 4 deletions

dask_ml/feature_extraction/text.py

Lines changed: 171 additions & 0 deletions
@@ -1,14 +1,20 @@
 """
 Utilities to build feature vectors from text documents.
 """
+import itertools
+
 import dask
 import dask.array as da
 import dask.bag as db
 import dask.dataframe as dd
+import distributed
 import numpy as np
 import scipy.sparse
 import sklearn.base
 import sklearn.feature_extraction.text
+from dask.delayed import Delayed
+from distributed import get_client, wait
+from sklearn.utils.validation import check_is_fitted
 
 
 class _BaseHasher(sklearn.base.BaseEstimator):
@@ -108,3 +114,168 @@ class FeatureHasher(_BaseHasher, sklearn.feature_extraction.text.FeatureHasher):
     @property
     def _hasher(self):
         return sklearn.feature_extraction.text.FeatureHasher
+
+
+class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
+    """Convert a collection of text documents to a matrix of token counts
+
+    Notes
+    -----
+    When a vocabulary isn't provided, ``fit_transform`` requires two
+    passes over the dataset: one to learn the vocabulary and a second
+    to transform the data. Consider persisting the data if it fits
+    in (distributed) memory prior to calling ``fit`` or ``transform``
+    when not providing a ``vocabulary``.
+
+    Additionally, this implementation benefits from having
+    an active ``dask.distributed.Client``, even on a single machine.
+    When a client is present, the learned ``vocabulary`` is persisted
+    in distributed memory, which saves some recomputation and redundant
+    communication.
+
+    See Also
+    --------
+    sklearn.feature_extraction.text.CountVectorizer
+
+    Examples
+    --------
+    The Dask-ML implementation currently requires that ``raw_documents``
+    is a :class:`dask.bag.Bag` of documents (lists of strings).
+
+    >>> from dask_ml.feature_extraction.text import CountVectorizer
+    >>> import dask.bag as db
+    >>> from distributed import Client
+    >>> client = Client()
+    >>> corpus = [
+    ...     'This is the first document.',
+    ...     'This document is the second document.',
+    ...     'And this is the third one.',
+    ...     'Is this the first document?',
+    ... ]
+    >>> corpus = db.from_sequence(corpus, npartitions=2)
+    >>> vectorizer = CountVectorizer()
+    >>> X = vectorizer.fit_transform(corpus)
+    dask.array<concatenate, shape=(nan, 9), dtype=int64, chunksize=(nan, 9), ...
+               chunktype=scipy.csr_matrix>
+    >>> X.compute().toarray()
+    array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
+           [0, 2, 0, 1, 0, 1, 1, 0, 1],
+           [1, 0, 0, 1, 1, 0, 1, 1, 1],
+           [0, 1, 1, 1, 0, 0, 1, 0, 1]])
+    >>> vectorizer.get_feature_names()
+    ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
+    """
+
+    def fit_transform(self, raw_documents, y=None):
+        params = self.get_params()
+        vocabulary = params.pop("vocabulary")
+
+        vocabulary_for_transform = vocabulary
+
+        if self.vocabulary is not None:
+            # Case 1: Just map transform.
+            fixed_vocabulary = True
+            n_features = vocabulary_length(vocabulary)
+            vocabulary_ = vocabulary
+        else:
+            fixed_vocabulary = False
+            # Case 2: learn vocabulary from the data.
+            vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
+            vocabulary = vocabulary_for_transform = _merge_vocabulary(
+                *vocabularies.to_delayed()
+            )
+            vocabulary_for_transform = vocabulary_for_transform.persist()
+            vocabulary_ = vocabulary.compute()
+
+            n_features = len(vocabulary_)
+
+        result = raw_documents.map_partitions(
+            _count_vectorizer_transform, vocabulary_for_transform, params
+        )
+
+        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        result = build_array(result, n_features, meta)
+
+        self.vocabulary_ = vocabulary_
+        self.fixed_vocabulary_ = fixed_vocabulary
+
+        return result
+
+    def transform(self, raw_documents):
+        params = self.get_params()
+        vocabulary = params.pop("vocabulary")
+
+        if vocabulary is None:
+            check_is_fitted(self, "vocabulary_")
+            vocabulary_for_transform = self.vocabulary_
+        else:
+            if isinstance(vocabulary, dict):
+                # scatter for the user
+                try:
+                    client = get_client()
+                except ValueError:
+                    vocabulary_for_transform = dask.delayed(vocabulary)
+                else:
+                    (vocabulary_for_transform,) = client.scatter(
+                        (vocabulary,), broadcast=True
+                    )
+            else:
+                vocabulary_for_transform = vocabulary
+
+        n_features = vocabulary_length(vocabulary_for_transform)
+        transformed = raw_documents.map_partitions(
+            _count_vectorizer_transform, vocabulary_for_transform, params
+        )
+        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        return build_array(transformed, n_features, meta)
+
+
+def build_array(bag, n_features, meta):
+    name = "from-bag-" + bag.name
+    layer = {(name, i, 0): (k, i) for k, i in bag.__dask_keys__()}
+    dsk = dask.highlevelgraph.HighLevelGraph.from_collections(
+        name, layer, dependencies=[bag]
+    )
+    chunks = ((np.nan,) * bag.npartitions, (n_features,))
+    return da.Array(dsk, name, chunks, meta=meta)
+
+
+def vocabulary_length(vocabulary):
+    if isinstance(vocabulary, dict):
+        return len(vocabulary)
+    elif isinstance(vocabulary, Delayed):
+        try:
+            return len(vocabulary)
+        except TypeError:
+            return len(vocabulary.compute())
+    elif isinstance(vocabulary, distributed.Future):
+        client = get_client()
+        future = client.submit(len, vocabulary)
+        wait(future)
+        result = future.result()
+        return result
+    else:
+        raise ValueError(f"Unknown vocabulary type {type(vocabulary)}.")
+
+
+def _count_vectorizer_transform(partition, vocabulary, params):
+    model = sklearn.feature_extraction.text.CountVectorizer(
+        vocabulary=vocabulary, **params
+    )
+    return model.transform(partition)
+
+
+def _build_vocabulary(partition, params):
+    model = sklearn.feature_extraction.text.CountVectorizer(**params)
+    model.fit(partition)
+    return set(model.vocabulary_)
+
+
+@dask.delayed
+def _merge_vocabulary(*vocabularies):
+    vocabulary = {
+        key: i
+        for i, key in enumerate(
+            sorted(set(itertools.chain.from_iterable(vocabularies)))
+        )
+    }
+    return vocabulary
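
A tiny worked example of what _merge_vocabulary computes (the token sets here are made up, not taken from the diff): each partition contributes a set of tokens via _build_vocabulary, and the merged vocabulary assigns indices in sorted order.

# Two hypothetical per-partition vocabularies, as _build_vocabulary would return.
part1 = {"pizza", "beer"}
part2 = {"pizza", "copyright"}

# _merge_vocabulary enumerates the sorted union of all partition sets.
merged = {key: i for i, key in enumerate(sorted(part1 | part2))}
assert merged == {"beer": 0, "copyright": 1, "pizza": 2}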

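The docstring's Notes recommend persisting the input when no vocabulary is given, since fit_transform makes two passes over the data. A minimal sketch of that pattern, assuming a local distributed Client and a hypothetical data/*.txt glob:

import dask.bag as db
from distributed import Client

from dask_ml.feature_extraction.text import CountVectorizer

client = Client()  # local cluster; also lets the learned vocabulary live in distributed memory

corpus = db.read_text("data/*.txt")  # hypothetical input location
corpus = corpus.persist()  # pin partitions in memory so both passes reuse them

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(corpus)  # pass 1 learns the vocabulary, pass 2 counts

Note that build_array sets the row chunks to np.nan, so X reports nan rows (as in the docstring example) until the result is actually computed.
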
dask_ml/utils.py

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,7 @@ def assert_estimator_equal(left, right, exclude=None, **kwargs):
     for attr in left_attrs2:
         l = getattr(left, attr)
         r = getattr(right, attr)
-        _assert_eq(l, r, **kwargs)
+        _assert_eq(l, r, name=attr, **kwargs)
 
 
 def check_array(
@@ -193,7 +193,7 @@ def check_array(
     return sk_validation.check_array(array, *args, **kwargs)
 
 
-def _assert_eq(l, r, **kwargs):
+def _assert_eq(l, r, name=None, **kwargs):
     array_types = (np.ndarray, da.Array)
     frame_types = (pd.core.generic.NDFrame, dd._Frame)
     if isinstance(l, array_types):
@@ -206,7 +206,7 @@ def _assert_eq(l, r, **kwargs):
         for a, b in zip(l, r):
             _assert_eq(a, b, **kwargs)
     else:
-        assert l == r
+        assert l == r, (name, l, r)
 
 
 def check_random_state(random_state):
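
Threading name=attr through means a failed scalar comparison now reports which attribute differed. A hypothetical illustration, using a simplified stand-in for the scalar branch of _assert_eq:

# Simplified stand-in for the scalar branch of dask_ml.utils._assert_eq.
def _assert_eq_scalar(l, r, name=None):
    assert l == r, (name, l, r)

try:
    _assert_eq_scalar(True, False, name="fixed_vocabulary_")
except AssertionError as exc:
    print(exc)  # ('fixed_vocabulary_', True, False)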

docs/source/modules/api.rst

Lines changed: 2 additions & 1 deletion
@@ -176,10 +176,11 @@ with Dask Arrays or DataFrames.
    :toctree: generated/
    :template: class.rst
 
+   feature_extraction.text.CountVectorizer
    feature_extraction.text.HashingVectorizer
    feature_extraction.text.FeatureHasher
 
-
+
 :mod:`dask_ml.compose`: Composite Estimators
 ============================================

tests/feature_extraction/test_text.py

Lines changed: 70 additions & 0 deletions
@@ -6,8 +6,10 @@
 import pytest
 import scipy.sparse
 import sklearn.feature_extraction.text
+from distributed import Client
 
 import dask_ml.feature_extraction.text
+from dask_ml._compat import dummy_context
 from dask_ml.utils import assert_estimator_equal
 
 JUNK_FOOD_DOCS = (
@@ -107,3 +109,71 @@ def test_correct_meta():
     assert scipy.sparse.issparse(result._meta)
     assert result._meta.dtype == "float64"
     assert result._meta.shape == (0, 0)
+
+
+@pytest.mark.parametrize("give_vocabulary", [True, False])
+@pytest.mark.parametrize("distributed", [True, False])
+def test_count_vectorizer(give_vocabulary, distributed):
+    m1 = sklearn.feature_extraction.text.CountVectorizer()
+    b = db.from_sequence(JUNK_FOOD_DOCS, npartitions=2)
+    r1 = m1.fit_transform(JUNK_FOOD_DOCS)
+
+    if give_vocabulary:
+        vocabulary = m1.vocabulary_
+        m1 = sklearn.feature_extraction.text.CountVectorizer(vocabulary=vocabulary)
+        r1 = m1.transform(JUNK_FOOD_DOCS)
+    else:
+        vocabulary = None
+
+    m2 = dask_ml.feature_extraction.text.CountVectorizer(vocabulary=vocabulary)
+
+    if distributed:
+        client = Client()  # noqa
+    else:
+        client = dummy_context()
+
+    if give_vocabulary:
+        r2 = m2.transform(b)
+    else:
+        r2 = m2.fit_transform(b)
+
+    with client:
+        exclude = {"vocabulary_actor_", "stop_words_"}
+        if give_vocabulary:
+            # In scikit-learn, `.transform()` sets these.
+            # This looks buggy.
+            exclude |= {"vocabulary_", "fixed_vocabulary_"}
+
+        assert_estimator_equal(m1, m2, exclude=exclude)
+        assert isinstance(r2, da.Array)
+        assert isinstance(r2._meta, scipy.sparse.csr_matrix)
+        np.testing.assert_array_equal(r1.toarray(), r2.compute().toarray())
+
+        r3 = m2.transform(b)
+        assert isinstance(r3, da.Array)
+        assert isinstance(r3._meta, scipy.sparse.csr_matrix)
+        np.testing.assert_array_equal(r1.toarray(), r3.compute().toarray())
+
+        if give_vocabulary:
+            r4 = m2.fit_transform(b)
+            assert isinstance(r4, da.Array)
+            assert isinstance(r4._meta, scipy.sparse.csr_matrix)
+            np.testing.assert_array_equal(r1.toarray(), r4.compute().toarray())
+
+
+def test_count_vectorizer_remote_vocabulary():
+    m1 = sklearn.feature_extraction.text.CountVectorizer().fit(JUNK_FOOD_DOCS)
+    vocabulary = m1.vocabulary_
+    r1 = m1.transform(JUNK_FOOD_DOCS)
+    b = db.from_sequence(JUNK_FOOD_DOCS, npartitions=2)
+
+    with Client() as client:
+        (remote_vocabulary,) = client.scatter((vocabulary,), broadcast=True)
+        m = dask_ml.feature_extraction.text.CountVectorizer(
+            vocabulary=remote_vocabulary
+        )
+        r2 = m.transform(b)
+
+        assert isinstance(r2, da.Array)
+        assert isinstance(r2._meta, scipy.sparse.csr_matrix)
+        np.testing.assert_array_equal(r1.toarray(), r2.compute().toarray())
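
test_count_vectorizer_remote_vocabulary exercises the Future branch of vocabulary_length: len runs on the cluster, next to the scattered vocabulary, instead of pulling the dict back to the caller. A minimal sketch of that pattern, assuming a local Client (vocabulary contents made up):

from distributed import Client, wait

with Client() as client:
    # Broadcast a small vocabulary to every worker, as the test does.
    (vocab,) = client.scatter(({"beer": 0, "pizza": 1},), broadcast=True)

    # distributed resolves the Future argument, so len() executes on a worker.
    future = client.submit(len, vocab)
    wait(future)
    assert future.result() == 2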
