Commit 39e9b3b
[SPARK-50932][ML][PYTHON][CONNECT] Support Bucketizer on Connect
### What changes were proposed in this pull request?
Support Bucketizer on Connect.

### Why are the changes needed?
For feature parity.

### Does this PR introduce _any_ user-facing change?
Yes, a new algorithm is supported on Connect.

### How was this patch tested?
Added a test.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49646 from zhengruifeng/ml_connect_bucketizer.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent c662441 commit 39e9b3b
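With this patch, a Bucketizer can be constructed and applied against a Spark Connect session the same way as in classic mode. A minimal usage sketch (the sc://localhost connection URL and the sample data are illustrative assumptions, not part of this commit):

    from pyspark.sql import SparkSession
    from pyspark.ml.feature import Bucketizer

    # Connect to a Spark Connect server (URL is an assumption for illustration).
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["v1"])

    # Bucket v1 into the three ranges delimited by the splits.
    bucketizer = Bucketizer(
        splits=[-float("inf"), 0.5, 1.4, float("inf")],
        inputCol="v1",
        outputCol="b1",
    )
    bucketizer.transform(df).show()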

File tree

3 files changed (+78, -12 lines)

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 ########### Transformers
 org.apache.spark.ml.feature.DCT
 org.apache.spark.ml.feature.Binarizer
+org.apache.spark.ml.feature.Bucketizer
 org.apache.spark.ml.feature.VectorAssembler
 org.apache.spark.ml.feature.Tokenizer
 org.apache.spark.ml.feature.RegexTokenizer

python/pyspark/ml/tests/test_feature.py

Lines changed: 72 additions & 5 deletions

@@ -25,6 +25,7 @@
 from pyspark.ml.feature import (
     DCT,
     Binarizer,
+    Bucketizer,
     CountVectorizer,
     CountVectorizerModel,
     HashingTF,
@@ -688,17 +689,17 @@ def test_binarizer(self):
             ["v1", "v2"],
         )
 
-        bucketizer = Binarizer(threshold=1.0, inputCol="v1", outputCol="f1")
-        output = bucketizer.transform(df)
+        binarizer = Binarizer(threshold=1.0, inputCol="v1", outputCol="f1")
+        output = binarizer.transform(df)
         self.assertEqual(output.columns, ["v1", "v2", "f1"])
         self.assertEqual(output.count(), 6)
         self.assertEqual(
             [r.f1 for r in output.select("f1").collect()],
             [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
         )
 
-        bucketizer = Binarizer(threshold=1.0, inputCols=["v1", "v2"], outputCols=["f1", "f2"])
-        output = bucketizer.transform(df)
+        binarizer = Binarizer(threshold=1.0, inputCols=["v1", "v2"], outputCols=["f1", "f2"])
+        output = binarizer.transform(df)
         self.assertEqual(output.columns, ["v1", "v2", "f1", "f2"])
         self.assertEqual(output.count(), 6)
         self.assertEqual(
@@ -712,8 +713,74 @@ def test_binarizer(self):
 
         # save & load
         with tempfile.TemporaryDirectory(prefix="binarizer") as d:
+            binarizer.write().overwrite().save(d)
+            binarizer2 = Binarizer.load(d)
+            self.assertEqual(str(binarizer), str(binarizer2))
+
+    def test_bucketizer(self):
+        df = self.spark.createDataFrame(
+            [
+                (0.1, 0.0),
+                (0.4, 1.0),
+                (1.2, 1.3),
+                (1.5, float("nan")),
+                (float("nan"), 1.0),
+                (float("nan"), 0.0),
+            ],
+            ["v1", "v2"],
+        )
+
+        splits = [-float("inf"), 0.5, 1.4, float("inf")]
+        bucketizer = Bucketizer()
+        bucketizer.setSplits(splits)
+        bucketizer.setHandleInvalid("keep")
+        bucketizer.setInputCol("v1")
+        bucketizer.setOutputCol("b1")
+
+        self.assertEqual(bucketizer.getSplits(), splits)
+        self.assertEqual(bucketizer.getHandleInvalid(), "keep")
+        self.assertEqual(bucketizer.getInputCol(), "v1")
+        self.assertEqual(bucketizer.getOutputCol(), "b1")
+
+        output = bucketizer.transform(df)
+        self.assertEqual(output.columns, ["v1", "v2", "b1"])
+        self.assertEqual(output.count(), 6)
+        self.assertEqual(
+            [r.b1 for r in output.select("b1").collect()],
+            [0.0, 0.0, 1.0, 2.0, 3.0, 3.0],
+        )
+
+        splitsArray = [
+            [-float("inf"), 0.5, 1.4, float("inf")],
+            [-float("inf"), 0.5, float("inf")],
+        ]
+        bucketizer = Bucketizer(
+            splitsArray=splitsArray,
+            inputCols=["v1", "v2"],
+            outputCols=["b1", "b2"],
+        )
+        bucketizer.setHandleInvalid("keep")
+        self.assertEqual(bucketizer.getSplitsArray(), splitsArray)
+        self.assertEqual(bucketizer.getHandleInvalid(), "keep")
+        self.assertEqual(bucketizer.getInputCols(), ["v1", "v2"])
+        self.assertEqual(bucketizer.getOutputCols(), ["b1", "b2"])
+
+        output = bucketizer.transform(df)
+        self.assertEqual(output.columns, ["v1", "v2", "b1", "b2"])
+        self.assertEqual(output.count(), 6)
+        self.assertEqual(
+            [r.b1 for r in output.select("b1").collect()],
+            [0.0, 0.0, 1.0, 2.0, 3.0, 3.0],
+        )
+        self.assertEqual(
+            [r.b2 for r in output.select("b2").collect()],
+            [0.0, 1.0, 1.0, 2.0, 1.0, 0.0],
+        )
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="bucketizer") as d:
             bucketizer.write().overwrite().save(d)
-            bucketizer2 = Binarizer.load(d)
+            bucketizer2 = Bucketizer.load(d)
             self.assertEqual(str(bucketizer), str(bucketizer2))
 
     def test_idf(self):
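For readers tracing the expected values in the new test: Bucketizer maps each input to the index of the half-open split interval [splits[i], splits[i+1]) that contains it, and with handleInvalid="keep" NaN values land in one extra bucket past the last regular one. A standalone sketch of that rule in plain Python (an illustration of the documented semantics, not Spark's implementation):

    import math
    from bisect import bisect_right

    def bucket(value, splits):
        # NaN is "invalid"; handleInvalid="keep" maps it to one extra bucket
        # past the last regular one (here: bucket 3 for 3 intervals).
        if math.isnan(value):
            return float(len(splits) - 1)
        # Index of the half-open interval [splits[i], splits[i+1]) holding value.
        return float(bisect_right(splits, value) - 1)

    splits = [-float("inf"), 0.5, 1.4, float("inf")]
    values = [0.1, 0.4, 1.2, 1.5, float("nan"), float("nan")]
    print([bucket(v, splits) for v in values])  # [0.0, 0.0, 1.0, 2.0, 3.0, 3.0]

This is why column v1 = [0.1, 0.4, 1.2, 1.5, NaN, NaN] with three intervals yields the asserted [0.0, 0.0, 1.0, 2.0, 3.0, 3.0].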

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 5 additions & 7 deletions

@@ -187,6 +187,8 @@ private[ml] object MLUtils {
       array.map(_.asInstanceOf[Double])
     } else if (elementType == classOf[String]) {
       array.map(_.asInstanceOf[String])
+    } else if (elementType.isArray && elementType.getComponentType == classOf[Double]) {
+      array.map(_.asInstanceOf[Array[_]].map(_.asInstanceOf[Double]))
     } else {
       throw MlUnsupportedException(
         s"array element type unsupported, " +
@@ -228,14 +230,10 @@
       value.asInstanceOf[String]
     } else if (paramType.isArray) {
       val compType = paramType.getComponentType
-      if (compType.isArray) {
-        throw MlUnsupportedException(s"Array of array unsupported")
-      } else {
-        val array = value.asInstanceOf[Array[_]].map { e =>
-          reconcileParam(compType, e)
-        }
-        reconcileArray(compType, array)
+      val array = value.asInstanceOf[Array[_]].map { e =>
+        reconcileParam(compType, e)
       }
+      reconcileArray(compType, array)
     } else {
       throw MlUnsupportedException(s"Unsupported parameter type, found ${paramType.getName}")
     }
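Bucketizer's splitsArray is an Array[Array[Double]] param, the first nested-array param to cross Connect, so the reconciliation above drops the old "Array of array unsupported" guard and instead recurses into each element, with reconcileArray gaining a branch for arrays of Array[Double]. The shape of that recursion can be sketched in Python (a toy analogy with a hypothetical helper, not the Scala server code above):

    def reconcile(param_type, value):
        # Recursively coerce a decoded payload into the expected type,
        # descending one level per array dimension.
        if isinstance(param_type, list):  # e.g. [[float]] -> [float] per element
            inner = param_type[0]
        return [reconcile(inner, v) for v in value]
        return param_type(value)  # scalar leaf, e.g. float

Wait — the early return above is a bug; the correct sketch is:

    def reconcile(param_type, value):
        # Recursively coerce a decoded payload into the expected type,
        # descending one level per array dimension.
        if isinstance(param_type, list):  # e.g. [[float]] -> [float] per element
            inner = param_type[0]
            return [reconcile(inner, v) for v in value]
        return param_type(value)  # scalar leaf, e.g. float

    # A splitsArray-like nested payload (Array[Array[Double]] on the JVM side).
    payload = [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]]
    print(reconcile([[float]], payload))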
