Skip to content

Commit f97ba24

Browse files
Fixed vocabulary length for remote vocabulary (#719)
* Fixed vocabulary length for remote vocabulary Closes #712 * optimize fit, transform
1 parent 7f6dcb5 commit f97ba24

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

dask_ml/feature_extraction/text.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ def fit_transform(self, raw_documents, y=None):
186186
)
187187
vocabulary_for_transform = vocabulary_for_transform.persist()
188188
vocabulary_ = vocabulary.compute()
189+
n_features = len(vocabulary_)
189190

190-
n_features = len(vocabulary_)
191191
result = raw_documents.map_partitions(
192192
_count_vectorizer_transform, vocabulary_for_transform, params
193193
)
@@ -206,20 +206,20 @@ def transform(self, raw_documents):
206206

207207
if vocabulary is None:
208208
check_is_fitted(self, "vocabulary_")
209-
vocabulary_for_transform = self.vocabulary_
210-
else:
211-
if isinstance(vocabulary, dict):
212-
# scatter for the user
213-
try:
214-
client = get_client()
215-
except ValueError:
216-
vocabulary_for_transform = dask.delayed(vocabulary)
217-
else:
218-
(vocabulary_for_transform,) = client.scatter(
219-
(vocabulary,), broadcast=True
220-
)
209+
vocabulary = self.vocabulary_
210+
211+
if isinstance(vocabulary, dict):
212+
# scatter for the user
213+
try:
214+
client = get_client()
215+
except ValueError:
216+
vocabulary_for_transform = dask.delayed(vocabulary)
221217
else:
222-
vocabulary_for_transform = vocabulary
218+
(vocabulary_for_transform,) = client.scatter(
219+
(vocabulary,), broadcast=True
220+
)
221+
else:
222+
vocabulary_for_transform = vocabulary
223223

224224
n_features = vocabulary_length(vocabulary_for_transform)
225225
transformed = raw_documents.map_partitions(

tests/feature_extraction/test_text.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,9 @@ def test_count_vectorizer_remote_vocabulary():
177177
assert isinstance(r2, da.Array)
178178
assert isinstance(r2._meta, scipy.sparse.csr_matrix)
179179
np.testing.assert_array_equal(r1.toarray(), r2.compute().toarray())
180+
181+
m = dask_ml.feature_extraction.text.CountVectorizer(
182+
vocabulary=remote_vocabulary
183+
)
184+
m.fit_transform(b)
185+
assert m.vocabulary_ is remote_vocabulary

0 commit comments

Comments
 (0)