@@ -186,8 +186,8 @@ def fit_transform(self, raw_documents, y=None):
186
186
)
187
187
vocabulary_for_transform = vocabulary_for_transform .persist ()
188
188
vocabulary_ = vocabulary .compute ()
189
+ n_features = len (vocabulary_ )
189
190
190
- n_features = len (vocabulary_ )
191
191
result = raw_documents .map_partitions (
192
192
_count_vectorizer_transform , vocabulary_for_transform , params
193
193
)
@@ -206,20 +206,20 @@ def transform(self, raw_documents):
206
206
207
207
if vocabulary is None :
208
208
check_is_fitted (self , "vocabulary_" )
209
- vocabulary_for_transform = self .vocabulary_
210
- else :
211
- if isinstance (vocabulary , dict ):
212
- # scatter for the user
213
- try :
214
- client = get_client ()
215
- except ValueError :
216
- vocabulary_for_transform = dask .delayed (vocabulary )
217
- else :
218
- (vocabulary_for_transform ,) = client .scatter (
219
- (vocabulary ,), broadcast = True
220
- )
209
+ vocabulary = self .vocabulary_
210
+
211
+ if isinstance (vocabulary , dict ):
212
+ # scatter for the user
213
+ try :
214
+ client = get_client ()
215
+ except ValueError :
216
+ vocabulary_for_transform = dask .delayed (vocabulary )
221
217
else :
222
- vocabulary_for_transform = vocabulary
218
+ (vocabulary_for_transform ,) = client .scatter (
219
+ (vocabulary ,), broadcast = True
220
+ )
221
+ else :
222
+ vocabulary_for_transform = vocabulary
223
223
224
224
n_features = vocabulary_length (vocabulary_for_transform )
225
225
transformed = raw_documents .map_partitions (
0 commit comments