@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
19
19
20
20
import java .util .Locale
21
21
22
+ import org .apache .spark .ml .Pipeline
22
23
import org .apache .spark .ml .util .{DefaultReadWriteTest , MLTest }
23
24
import org .apache .spark .sql .{DataFrame , Row }
24
25
@@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
181
182
}
182
183
183
184
// Passes two differently configured StopWordsRemover instances to testDefaultReadWrite:
// one using the single-column setters, one using the multi-column setters.
// NOTE(review): testDefaultReadWrite is defined outside this view; presumably it
// round-trips the instance through save/load -- confirm against DefaultReadWriteTest.
test(" read/write" ) {
184
- val t = new StopWordsRemover ()
185
// Single-column configuration (setInputCol / setOutputCol).
+ val t1 = new StopWordsRemover ()
185
186
.setInputCol(" myInputCol" )
186
187
.setOutputCol(" myOutputCol" )
187
188
.setStopWords(Array (" the" , " a" ))
188
189
.setCaseSensitive(true )
189
- testDefaultReadWrite(t)
190
+ testDefaultReadWrite(t1)
191
+
192
// Multi-column configuration (setInputCols / setOutputCols) with the same
// stop words and case-sensitivity setting as t1.
+ val t2 = new StopWordsRemover ()
193
+ .setInputCols(Array (" input1" , " input2" , " input3" ))
194
+ .setOutputCols(Array (" result1" , " result2" , " result3" ))
195
+ .setStopWords(Array (" the" , " a" ))
196
+ .setCaseSensitive(true )
197
+ testDefaultReadWrite(t2)
190
198
}
191
199
192
200
test(" StopWordsRemover output column already exists" ) {
@@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
199
207
testTransformerByInterceptingException[(Array [String ], Array [String ])](
200
208
dataSet,
201
209
remover,
202
- s " requirement failed: Column $outputCol already exists. " ,
210
+ s " requirement failed: Output Column $outputCol already exists. " ,
203
211
" expected" )
204
212
}
205
213
@@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
217
225
Locale .setDefault(oldDefault)
218
226
}
219
227
}
228
+
229
// Multi-column API with no explicit stop-word list: setStopWords is never called, so
// whatever defaults the remover ships with apply. Each "rawN" column is filtered into
// "filteredN" and compared, row by row, against the hand-written "expectedN" column.
test("Multiple Columns: StopWordsRemover default") {
  val remover = new StopWordsRemover()
    .setInputCols(Array("raw1", "raw2"))
    .setOutputCols(Array("filtered1", "filtered2"))

  // Tuple layout: (raw1, raw2, expected1, expected2).
  // The rows cover ordinary tokens, lower- and upper-case stop words, a null token
  // and empty sequences.
  val df = Seq(
    (Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")),
    (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
    (Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()),
    (Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()),
    (Seq(null), Seq(null), Seq(null), Seq(null)),
    (Seq(), Seq(), Seq(), Seq())
  ).toDF("raw1", "raw2", "expected1", "expected2")

  remover.transform(df)
    .select("filtered1", "expected1", "filtered2", "expected2")
    .collect().foreach {
      // Seq[_] rather than Seq[String]: the element type is erased at runtime, so a
      // Seq[String] pattern would be an unchecked match and promise nothing extra.
      case Row(r1: Seq[_], e1: Seq[_], r2: Seq[_], e2: Seq[_]) =>
        assert(r1 === e1,
          s"The result value is not correct after removing stop words. " +
            s"Expected $e1 but found $r1")
        assert(r2 === e2,
          s"The result value is not correct after removing stop words. " +
            s"Expected $e2 but found $r2")
    }
}
252
+
253
// Multi-column API with an explicit stop-word list. The same list is applied to both
// input columns; each "filteredN" output is compared, row by row, against the
// hand-written "expectedN" column.
test("Multiple Columns: StopWordsRemover with particular stop words list") {
  val stopWords = Array("test", "a", "an", "the")
  val remover = new StopWordsRemover()
    .setInputCols(Array("raw1", "raw2"))
    .setOutputCols(Array("filtered1", "filtered2"))
    .setStopWords(stopWords)

  // Tuple layout: (raw1, raw2, expected1, expected2).
  // "test1"/"test2" must survive even though "test" is a stop word, and the
  // upper-case row is expected to be emptied as well.
  val df = Seq(
    (Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")),
    (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
    (Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")),
    (Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()),
    (Seq(null), Seq(null), Seq(null), Seq(null)),
    (Seq(), Seq(), Seq(), Seq())
  ).toDF("raw1", "raw2", "expected1", "expected2")

  remover.transform(df)
    .select("filtered1", "expected1", "filtered2", "expected2")
    .collect().foreach {
      // Seq[_] rather than Seq[String]: the element type is erased at runtime, so a
      // Seq[String] pattern would be an unchecked match and promise nothing extra.
      case Row(r1: Seq[_], e1: Seq[_], r2: Seq[_], e2: Seq[_]) =>
        assert(r1 === e1,
          s"The result value is not correct after removing stop words. " +
            s"Expected $e1 but found $r1")
        assert(r2 === e2,
          s"The result value is not correct after removing stop words. " +
            s"Expected $e2 but found $r2")
    }
}
278
+
279
// Equivalence check: one StopWordsRemover configured with two input/output columns
// must produce the same "output1"/"output2" values as two single-column removers
// chained in a Pipeline over the same data.
test("Compare single/multiple column(s) StopWordsRemover in pipeline") {
  val df = Seq(
    (Seq("test", "test"), Seq("test1", "test2")),
    (Seq("a", "b", "c", "d"), Seq("a", "b")),
    (Seq("a", "the", "an"), Seq("a", "the", "test1")),
    (Seq("A", "The", "AN"), Seq("A", "The", "AN")),
    (Seq(null), Seq(null)),
    (Seq(), Seq())
  ).toDF("input1", "input2")

  // Pipeline under test: a single multi-column stage.
  val multiColsRemover = new StopWordsRemover()
    .setInputCols(Array("input1", "input2"))
    .setOutputCols(Array("output1", "output2"))

  val plForMultiCols = new Pipeline()
    .setStages(Array(multiColsRemover))
    .fit(df)

  // Reference pipeline: one single-column stage per input column.
  val removerForCol1 = new StopWordsRemover()
    .setInputCol("input1")
    .setOutputCol("output1")
  val removerForCol2 = new StopWordsRemover()
    .setInputCol("input2")
    .setOutputCol("output2")

  val plForSingleCol = new Pipeline()
    .setStages(Array(removerForCol1, removerForCol2))
    .fit(df)

  val resultForSingleCol = plForSingleCol.transform(df)
    .select("output1", "output2")
    .collect()
  val resultForMultiCols = plForMultiCols.transform(df)
    .select("output1", "output2")
    .collect()

  // zip stops at the shorter array, so without this check a pipeline that lost rows
  // would still pass the pairwise comparison below.
  assert(resultForSingleCol.length === resultForMultiCols.length,
    s"Row count differs: single-column pipeline produced ${resultForSingleCol.length} " +
      s"rows but multi-column pipeline produced ${resultForMultiCols.length}")

  resultForSingleCol.zip(resultForMultiCols).foreach {
    case (rowForSingle, rowForMultiCols) =>
      assert(rowForSingle === rowForMultiCols)
  }
}
320
+
321
// One input column paired with two output columns is an inconsistent configuration;
// applying the remover is expected to raise IllegalArgumentException.
test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
  val data = Seq(
    (Seq("A"), Seq("A")),
    (Seq("The", "the"), Seq("The"))
  ).toDF("input1", "input2")

  val mismatchedRemover = new StopWordsRemover()
    .setInputCols(Array("input1"))
    .setOutputCols(Array("result1", "result2"))

  intercept[IllegalArgumentException] {
    // Both transform() and count() run inside the intercepted block, so the test
    // passes whichever of the two raises the exception.
    mismatchedRemover.transform(data).count()
  }
}
333
+
334
// Setting the single-column inputCol on top of the multi-column inputCols/outputCols
// is a conflicting configuration; applying the remover is expected to raise
// IllegalArgumentException.
test("Multiple Columns: Set both of inputCol/inputCols") {
  val data = Seq(
    (Seq("A"), Seq("A")),
    (Seq("The", "the"), Seq("The"))
  ).toDF("input1", "input2")

  // Setter order is kept as in the original: multi-column params first, then inputCol.
  val conflictingRemover = new StopWordsRemover()
    .setInputCols(Array("input1", "input2"))
    .setOutputCols(Array("result1", "result2"))
    .setInputCol("input1")

  intercept[IllegalArgumentException] {
    // Both transform() and count() run inside the intercepted block, so the test
    // passes whichever of the two raises the exception.
    conflictingRemover.transform(data).count()
  }
}
220
347
}
0 commit comments