Skip to content

Commit e887f05

Browse files
committed
[SPARK-50933][ML][PYTHON][CONNECT] Support Feature Selectors on Connect
### What changes were proposed in this pull request?
Support Feature Selectors on Connect:
- ChiSqSelector
- UnivariateFeatureSelector
- VarianceThresholdSelector

### Why are the changes needed?
Feature parity.

### Does this PR introduce _any_ user-facing change?
Yes, new algorithms supported on Connect.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49641 from zhengruifeng/ml_connect_selector.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 4e3b831 commit e887f05

File tree

7 files changed

+120
-0
lines changed

7 files changed

+120
-0
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ org.apache.spark.ml.feature.StandardScaler
5252
org.apache.spark.ml.feature.MaxAbsScaler
5353
org.apache.spark.ml.feature.MinMaxScaler
5454
org.apache.spark.ml.feature.RobustScaler
55+
org.apache.spark.ml.feature.ChiSqSelector
56+
org.apache.spark.ml.feature.UnivariateFeatureSelector
57+
org.apache.spark.ml.feature.VarianceThresholdSelector
5558
org.apache.spark.ml.feature.StringIndexer
5659
org.apache.spark.ml.feature.PCA
5760
org.apache.spark.ml.feature.Word2Vec

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ org.apache.spark.ml.feature.StandardScalerModel
5656
org.apache.spark.ml.feature.MaxAbsScalerModel
5757
org.apache.spark.ml.feature.MinMaxScalerModel
5858
org.apache.spark.ml.feature.RobustScalerModel
59+
org.apache.spark.ml.feature.ChiSqSelectorModel
60+
org.apache.spark.ml.feature.UnivariateFeatureSelectorModel
61+
org.apache.spark.ml.feature.VarianceThresholdSelectorModel
5962
org.apache.spark.ml.feature.StringIndexerModel
6063
org.apache.spark.ml.feature.PCAModel
6164
org.apache.spark.ml.feature.Word2VecModel

mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ final class ChiSqSelectorModel private[ml] (
137137

138138
import ChiSqSelectorModel._
139139

140+
private[ml] def this() = this(
141+
Identifiable.randomUID("chiSqSelector"), Array.emptyIntArray)
142+
140143
override protected def isNumericAttribute = false
141144

142145
/** @group setParam */

mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ class UnivariateFeatureSelectorModel private[ml](
289289
extends Model[UnivariateFeatureSelectorModel] with UnivariateFeatureSelectorParams
290290
with MLWritable {
291291

292+
private[ml] def this() = this(
293+
Identifiable.randomUID("UnivariateFeatureSelector"), Array.emptyIntArray)
294+
292295
/** @group setParam */
293296
@Since("3.1.1")
294297
def setFeaturesCol(value: String): this.type = set(featuresCol, value)

mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ class VarianceThresholdSelectorModel private[ml](
126126
extends Model[VarianceThresholdSelectorModel] with VarianceThresholdSelectorParams
127127
with MLWritable {
128128

129+
private[ml] def this() = this(
130+
Identifiable.randomUID("VarianceThresholdSelector"), Array.emptyIntArray)
131+
129132
if (selectedFeatures.length >= 2) {
130133
require(selectedFeatures.sliding(2).forall(l => l(0) < l(1)),
131134
"Index should be strictly increasing.")

python/pyspark/ml/tests/test_feature.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@
4242
MinMaxScalerModel,
4343
RobustScaler,
4444
RobustScalerModel,
45+
ChiSqSelector,
46+
ChiSqSelectorModel,
47+
UnivariateFeatureSelector,
48+
UnivariateFeatureSelectorModel,
49+
VarianceThresholdSelector,
50+
VarianceThresholdSelectorModel,
4551
StopWordsRemover,
4652
StringIndexer,
4753
StringIndexerModel,
@@ -391,6 +397,102 @@ def test_robust_scaler(self):
391397
self.assertEqual(str(model), str(model2))
392398
self.assertEqual(model2.getOutputCol(), "scaled")
393399

400+
def test_chi_sq_selector(self):
401+
df = self.spark.createDataFrame(
402+
[
403+
(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
404+
(Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
405+
(Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0),
406+
],
407+
["features", "label"],
408+
)
409+
410+
selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
411+
self.assertEqual(selector.getNumTopFeatures(), 1)
412+
self.assertEqual(selector.getOutputCol(), "selectedFeatures")
413+
414+
model = selector.fit(df)
415+
self.assertEqual(model.selectedFeatures, [2])
416+
417+
output = model.transform(df)
418+
self.assertEqual(output.columns, ["features", "label", "selectedFeatures"])
419+
self.assertEqual(output.count(), 3)
420+
421+
# save & load
422+
with tempfile.TemporaryDirectory(prefix="chi_sq_selector") as d:
423+
selector.write().overwrite().save(d)
424+
selector2 = ChiSqSelector.load(d)
425+
self.assertEqual(str(selector), str(selector2))
426+
427+
model.write().overwrite().save(d)
428+
model2 = ChiSqSelectorModel.load(d)
429+
self.assertEqual(str(model), str(model2))
430+
431+
def test_univariate_selector(self):
432+
df = self.spark.createDataFrame(
433+
[
434+
(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
435+
(Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
436+
(Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0),
437+
],
438+
["features", "label"],
439+
)
440+
441+
selector = UnivariateFeatureSelector(outputCol="selectedFeatures")
442+
selector.setFeatureType("continuous").setLabelType("categorical").setSelectionThreshold(1)
443+
self.assertEqual(selector.getFeatureType(), "continuous")
444+
self.assertEqual(selector.getLabelType(), "categorical")
445+
self.assertEqual(selector.getOutputCol(), "selectedFeatures")
446+
self.assertEqual(selector.getSelectionThreshold(), 1)
447+
448+
model = selector.fit(df)
449+
self.assertEqual(model.selectedFeatures, [3])
450+
451+
output = model.transform(df)
452+
self.assertEqual(output.columns, ["features", "label", "selectedFeatures"])
453+
self.assertEqual(output.count(), 3)
454+
455+
# save & load
456+
with tempfile.TemporaryDirectory(prefix="univariate_selector") as d:
457+
selector.write().overwrite().save(d)
458+
selector2 = UnivariateFeatureSelector.load(d)
459+
self.assertEqual(str(selector), str(selector2))
460+
461+
model.write().overwrite().save(d)
462+
model2 = UnivariateFeatureSelectorModel.load(d)
463+
self.assertEqual(str(model), str(model2))
464+
465+
def test_variance_threshold_selector(self):
466+
df = self.spark.createDataFrame(
467+
[
468+
(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
469+
(Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
470+
(Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0),
471+
],
472+
["features", "label"],
473+
)
474+
475+
selector = VarianceThresholdSelector(varianceThreshold=2, outputCol="selectedFeatures")
476+
self.assertEqual(selector.getVarianceThreshold(), 2)
477+
self.assertEqual(selector.getOutputCol(), "selectedFeatures")
478+
479+
model = selector.fit(df)
480+
self.assertEqual(model.selectedFeatures, [2])
481+
482+
output = model.transform(df)
483+
self.assertEqual(output.columns, ["features", "label", "selectedFeatures"])
484+
self.assertEqual(output.count(), 3)
485+
486+
# save & load
487+
with tempfile.TemporaryDirectory(prefix="variance_threshold_selector") as d:
488+
selector.write().overwrite().save(d)
489+
selector2 = VarianceThresholdSelector.load(d)
490+
self.assertEqual(str(selector), str(selector2))
491+
492+
model.write().overwrite().save(d)
493+
model2 = VarianceThresholdSelectorModel.load(d)
494+
self.assertEqual(str(model), str(model2))
495+
394496
def test_word2vec(self):
395497
sent = ("a b " * 100 + "a c " * 10).split(" ")
396498
df = self.spark.createDataFrame([(sent,), (sent,)], ["sentence"]).coalesce(1)

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,9 @@ private[ml] object MLUtils {
589589
(classOf[MaxAbsScalerModel], Set("maxAbs")),
590590
(classOf[MinMaxScalerModel], Set("originalMax", "originalMin")),
591591
(classOf[RobustScalerModel], Set("range", "median")),
592+
(classOf[ChiSqSelectorModel], Set("selectedFeatures")),
593+
(classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
594+
(classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
592595
(classOf[PCAModel], Set("pc", "explainedVariance")),
593596
(classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")))
594597

0 commit comments

Comments
 (0)