Commit 1f4075d

huaxingao authored and srowen committed
[SPARK-29808][ML][PYTHON] StopWordsRemover should support multi-cols
### What changes were proposed in this pull request?
Add multi-column support to StopWordsRemover.

### Why are the changes needed?
As a basic Transformer, StopWordsRemover should support multiple columns. The `stopWords` param is applied across all columns.

### Does this PR introduce any user-facing change?
Yes: `StopWordsRemover.setInputCols` and `StopWordsRemover.setOutputCols`.

### How was this patch tested?
Unit tests.

Closes apache#26480 from huaxingao/spark-29808.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 8c2bf64 commit 1f4075d
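For illustration, here is a minimal sketch of the new multi-column usage in Scala. This is not code from the commit: it assumes an active `SparkSession` named `spark`, and the data and column names are made up. It relies only on the `setInputCols`/`setOutputCols` setters added by this PR, with the shared `stopWords` param (defaulting to the English list) applied to every input column.

```scala
import org.apache.spark.ml.feature.StopWordsRemover

// Hypothetical data: two array<string> columns to clean in one pass.
val df = spark.createDataFrame(Seq(
  (Seq("the", "quick", "fox"), Seq("a", "lazy", "dog"))
)).toDF("text1", "text2")

// One remover handles both columns; the stopWords param is shared across all of them.
val remover = new StopWordsRemover()
  .setInputCols(Array("text1", "text2"))
  .setOutputCols(Array("clean1", "clean2"))

remover.transform(df).show(truncate = false)
// With the default English stop words, "the" and "a" are dropped:
// clean1 = [quick, fox], clean2 = [lazy, dog]
```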

File tree

3 files changed: +217 -17 lines

mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala

Lines changed: 53 additions & 9 deletions
```diff
@@ -22,23 +22,28 @@ import java.util.Locale
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
 
 /**
  * A feature transformer that filters out stop words from input.
  *
+ * Since 3.0.0, `StopWordsRemover` can filter out multiple columns at once by setting the
+ * `inputCols` parameter. Note that when both the `inputCol` and `inputCols` parameters are set,
+ * an Exception will be thrown.
+ *
  * @note null values from input array are preserved unless adding null to stopWords
  * explicitly.
  *
  * @see <a href="http://en.wikipedia.org/wiki/Stop_words">Stop words (Wikipedia)</a>
  */
 @Since("1.5.0")
 class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
-  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
+  extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols
+    with DefaultParamsWritable {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("stopWords"))
@@ -51,6 +56,14 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
   @Since("1.5.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
   /**
    * The words to be filtered out.
    * Default: English stop words
@@ -121,6 +134,15 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
     }
   }
 
+  /** Returns the input and output column names as a pair of parallel arrays. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
     caseSensitive -> false, locale -> getDefaultOrUS.toString)
 
@@ -142,16 +164,38 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
         terms.filter(s => !lowerStopWords.contains(toLower(s)))
       }
     }
-    val metadata = outputSchema($(outputCol)).metadata
-    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
+
+    val (inputColNames, outputColNames) = getInOutCols()
+    val outputColumns = inputColNames.map { inputColName =>
+      t(col(inputColName))
+    }
+    val outputMetadata = outputColNames.map(outputSchema(_).metadata)
+    dataset.withColumns(outputColNames, outputColumns, outputMetadata)
   }
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
-      s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
-    SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol),
+      Seq(outputCols))
+
+    if (isSet(inputCols)) {
+      require(getInputCols.length == getOutputCols.length,
+        s"StopWordsRemover $this has mismatched Params " +
+        s"for multi-column transform. Params ($inputCols, $outputCols) should have " +
+        "equal lengths, but they have different lengths: " +
+        s"(${getInputCols.length}, ${getOutputCols.length}).")
+    }
+
+    val (inputColNames, outputColNames) = getInOutCols()
+    val newCols = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
+      require(!schema.fieldNames.contains(outputColName),
+        s"Output Column $outputColName already exists.")
+      val inputType = schema(inputColName).dataType
+      require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
+        s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
+      StructField(outputColName, inputType, schema(inputColName).nullable)
+    }
+    StructType(schema.fields ++ newCols)
   }
 
   @Since("1.5.0")
```

mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala

Lines changed: 130 additions & 3 deletions
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import java.util.Locale
 
+import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -181,12 +182,19 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("read/write") {
-    val t = new StopWordsRemover()
+    val t1 = new StopWordsRemover()
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
       .setStopWords(Array("the", "a"))
       .setCaseSensitive(true)
-    testDefaultReadWrite(t)
+    testDefaultReadWrite(t1)
+
+    val t2 = new StopWordsRemover()
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("result1", "result2", "result3"))
+      .setStopWords(Array("the", "a"))
+      .setCaseSensitive(true)
+    testDefaultReadWrite(t2)
   }
 
   test("StopWordsRemover output column already exists") {
@@ -199,7 +207,7 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
     testTransformerByInterceptingException[(Array[String], Array[String])](
      dataSet,
      remover,
-      s"requirement failed: Column $outputCol already exists.",
+      s"requirement failed: Output Column $outputCol already exists.",
      "expected")
   }
 
@@ -217,4 +225,123 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
     Locale.setDefault(oldDefault)
   }
+
+  test("Multiple Columns: StopWordsRemover default") {
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("raw1", "raw2"))
+      .setOutputCols(Array("filtered1", "filtered2"))
+    val df = Seq(
+      (Seq("test", "test"), Seq("test1", "test2"), Seq("test", "test"), Seq("test1", "test2")),
+      (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
+      (Seq("a", "the", "an"), Seq("the", "an"), Seq(), Seq()),
+      (Seq("A", "The", "AN"), Seq("A", "The"), Seq(), Seq()),
+      (Seq(null), Seq(null), Seq(null), Seq(null)),
+      (Seq(), Seq(), Seq(), Seq())
+    ).toDF("raw1", "raw2", "expected1", "expected2")
+
+    remover.transform(df)
+      .select("filtered1", "expected1", "filtered2", "expected2")
+      .collect().foreach {
+        case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
+          assert(r1 === e1,
+            s"The result value is not correct after removing stop words. Expected $e1 but found $r1")
+          assert(r2 === e2,
+            s"The result value is not correct after removing stop words. Expected $e2 but found $r2")
+      }
+  }
+
+  test("Multiple Columns: StopWordsRemover with particular stop words list") {
+    val stopWords = Array("test", "a", "an", "the")
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("raw1", "raw2"))
+      .setOutputCols(Array("filtered1", "filtered2"))
+      .setStopWords(stopWords)
+    val df = Seq(
+      (Seq("test", "test"), Seq("test1", "test2"), Seq(), Seq("test1", "test2")),
+      (Seq("a", "b", "c", "d"), Seq("a", "b"), Seq("b", "c", "d"), Seq("b")),
+      (Seq("a", "the", "an"), Seq("a", "the", "test1"), Seq(), Seq("test1")),
+      (Seq("A", "The", "AN"), Seq("A", "The", "AN"), Seq(), Seq()),
+      (Seq(null), Seq(null), Seq(null), Seq(null)),
+      (Seq(), Seq(), Seq(), Seq())
+    ).toDF("raw1", "raw2", "expected1", "expected2")
+
+    remover.transform(df)
+      .select("filtered1", "expected1", "filtered2", "expected2")
+      .collect().foreach {
+        case Row(r1: Seq[String], e1: Seq[String], r2: Seq[String], e2: Seq[String]) =>
+          assert(r1 === e1,
+            s"The result value is not correct after removing stop words. Expected $e1 but found $r1")
+          assert(r2 === e2,
+            s"The result value is not correct after removing stop words. Expected $e2 but found $r2")
+      }
+  }
+
+  test("Compare single/multiple column(s) StopWordsRemover in pipeline") {
+    val df = Seq(
+      (Seq("test", "test"), Seq("test1", "test2")),
+      (Seq("a", "b", "c", "d"), Seq("a", "b")),
+      (Seq("a", "the", "an"), Seq("a", "the", "test1")),
+      (Seq("A", "The", "AN"), Seq("A", "The", "AN")),
+      (Seq(null), Seq(null)),
+      (Seq(), Seq())
+    ).toDF("input1", "input2")
+
+    val multiColsRemover = new StopWordsRemover()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+
+    val plForMultiCols = new Pipeline()
+      .setStages(Array(multiColsRemover))
+      .fit(df)
+
+    val removerForCol1 = new StopWordsRemover()
+      .setInputCol("input1")
+      .setOutputCol("output1")
+    val removerForCol2 = new StopWordsRemover()
+      .setInputCol("input2")
+      .setOutputCol("output2")
+
+    val plForSingleCol = new Pipeline()
+      .setStages(Array(removerForCol1, removerForCol2))
+      .fit(df)
+
+    val resultForSingleCol = plForSingleCol.transform(df)
+      .select("output1", "output2")
+      .collect()
+    val resultForMultiCols = plForMultiCols.transform(df)
+      .select("output1", "output2")
+      .collect()
+
+    resultForSingleCol.zip(resultForMultiCols).foreach {
+      case (rowForSingle, rowForMultiCols) =>
+        assert(rowForSingle === rowForMultiCols)
+    }
+  }
+
+  test("Multiple Columns: Mismatched sizes of inputCols/outputCols") {
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("input1"))
+      .setOutputCols(Array("result1", "result2"))
+    val df = Seq(
+      (Seq("A"), Seq("A")),
+      (Seq("The", "the"), Seq("The"))
+    ).toDF("input1", "input2")
+    intercept[IllegalArgumentException] {
+      remover.transform(df).count()
+    }
+  }
+
+  test("Multiple Columns: Set both of inputCol/inputCols") {
+    val remover = new StopWordsRemover()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setInputCol("input1")
+    val df = Seq(
+      (Seq("A"), Seq("A")),
+      (Seq("The", "the"), Seq("The"))
+    ).toDF("input1", "input2")
+    intercept[IllegalArgumentException] {
+      remover.transform(df).count()
+    }
+  }
 }
```

python/pyspark/ml/feature.py

Lines changed: 34 additions & 5 deletions
```diff
@@ -3774,9 +3774,13 @@ def setOutputCol(self, value):
         return self._set(outputCol=value)
 
 
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
+                       JavaMLReadable, JavaMLWritable):
     """
     A feature transformer that filters out stop words from input.
+    Since 3.0.0, :py:class:`StopWordsRemover` can filter out multiple columns at once by setting
+    the :py:attr:`inputCols` parameter. Note that when both the :py:attr:`inputCol` and
+    :py:attr:`inputCols` parameters are set, an Exception will be thrown.
 
     .. note:: null values from input array are preserved unless adding null to stopWords explicitly.
 
@@ -3795,6 +3799,17 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
     True
     >>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
     True
+    >>> df2 = spark.createDataFrame([(["a", "b", "c"], ["a", "b"])], ["text1", "text2"])
+    >>> remover2 = StopWordsRemover(stopWords=["b"])
+    >>> remover2.setInputCols(["text1", "text2"]).setOutputCols(["words1", "words2"])
+    StopWordsRemover...
+    >>> remover2.transform(df2).show()
+    +---------+------+------+------+
+    |    text1| text2|words1|words2|
+    +---------+------+------+------+
+    |[a, b, c]|[a, b]|[a, c]|   [a]|
+    +---------+------+------+------+
+    ...
 
     .. versionadded:: 1.6.0
     """
@@ -3808,10 +3823,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
 
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
-                 locale=None):
+                 locale=None, inputCols=None, outputCols=None):
         """
         __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
-                 locale=None)
+                 locale=None, inputCols=None, outputCols=None)
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
@@ -3824,10 +3839,10 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=
     @keyword_only
     @since("1.6.0")
     def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
-                  locale=None):
+                  locale=None, inputCols=None, outputCols=None):
         """
         setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
-                  locale=None)
+                  locale=None, inputCols=None, outputCols=None)
         Sets params for this StopWordsRemover.
         """
         kwargs = self._input_kwargs
@@ -3887,6 +3902,20 @@ def setOutputCol(self, value):
         """
         return self._set(outputCol=value)
 
+    @since("3.0.0")
+    def setInputCols(self, value):
+        """
+        Sets the value of :py:attr:`inputCols`.
+        """
+        return self._set(inputCols=value)
+
+    @since("3.0.0")
+    def setOutputCols(self, value):
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        return self._set(outputCols=value)
+
     @staticmethod
     @since("2.0.0")
     def loadDefaultStopWords(language):
```
