
Commit 5c1f7c2

[SPARK-50934][ML][PYTHON][CONNECT] Support CountVectorizer and OneHotEncoder on Connect
### What changes were proposed in this pull request?
Support CountVectorizer and OneHotEncoder on Connect.

### Why are the changes needed?
For feature parity.

### Does this PR introduce _any_ user-facing change?
Yes.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49647 from zhengruifeng/ml_connect_cv.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 39e9b3b commit 5c1f7c2
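
For orientation, here is a minimal PySpark sketch of what this change enables from a Spark Connect client. The sc://localhost URL and the toy data are placeholders for illustration, not taken from the commit; the diff below registers the two estimators with the Connect server and whitelists their model attributes.

# Hedged sketch: fitting and applying the two newly supported estimators over Spark Connect.
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizer, OneHotEncoder

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # placeholder remote URL

docs = spark.createDataFrame(
    [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], ["label", "raw"])
cv_model = CountVectorizer(inputCol="raw", outputCol="vectors").fit(docs)
cv_model.transform(docs).show()

cats = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
ohe_model = OneHotEncoder(inputCols=["input"], outputCols=["output"]).fit(cats)
ohe_model.transform(cats).show()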

File tree

6 files changed (+70, -1 lines changed)


mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator

Lines changed: 3 additions & 0 deletions
@@ -58,3 +58,6 @@ org.apache.spark.ml.feature.VarianceThresholdSelector
 org.apache.spark.ml.feature.StringIndexer
 org.apache.spark.ml.feature.PCA
 org.apache.spark.ml.feature.Word2Vec
+org.apache.spark.ml.feature.CountVectorizer
+org.apache.spark.ml.feature.OneHotEncoder

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 3 additions & 0 deletions
@@ -63,3 +63,6 @@ org.apache.spark.ml.feature.VarianceThresholdSelectorModel
 org.apache.spark.ml.feature.StringIndexerModel
 org.apache.spark.ml.feature.PCAModel
 org.apache.spark.ml.feature.Word2VecModel
+org.apache.spark.ml.feature.CountVectorizerModel
+org.apache.spark.ml.feature.OneHotEncoderModel

mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala

Lines changed: 2 additions & 0 deletions
@@ -277,6 +277,8 @@ class CountVectorizerModel(

   import CountVectorizerModel._

+  private[ml] def this() = this(Identifiable.randomUID("cntVecModel"), Array.empty)
+
   @Since("1.5.0")
   def this(vocabulary: Array[String]) = {
     this(Identifiable.randomUID("cntVecModel"), vocabulary)

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

Lines changed: 2 additions & 0 deletions
@@ -234,6 +234,8 @@ class OneHotEncoderModel private[ml] (

   import OneHotEncoderModel._

+  private[ml] def this() = this(Identifiable.randomUID("oneHotEncoder)"), Array.emptyIntArray)
+
   // Returns the category size for each index with `dropLast` and `handleInvalid`
   // taken into account.
   private def getConfigedCategorySizes: Array[Int] = {
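
Both model classes gain a private zero-argument constructor. A plausible reading (an inference, not stated in the diff) is that the Connect server instantiates model classes reflectively, for example when materializing a saved or transferred model, before their learned state is attached. From the Python client, the visible effect is simply that model load works over Connect, roughly as in this hedged sketch; paths are placeholders and an active Connect session is assumed.

# Hedged sketch: client-side load of previously saved models; on the server,
# a default constructor lets the empty model be created before its state is restored.
from pyspark.ml.feature import CountVectorizerModel, OneHotEncoderModel

cv_model = CountVectorizerModel.load("/tmp/cv_model")    # placeholder path
ohe_model = OneHotEncoderModel.load("/tmp/ohe_model")    # placeholder path
print(cv_model.vocabulary, ohe_model.categorySizes)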

python/pyspark/ml/tests/test_feature.py

Lines changed: 57 additions & 0 deletions
@@ -28,6 +28,8 @@
     Bucketizer,
     CountVectorizer,
     CountVectorizerModel,
+    OneHotEncoder,
+    OneHotEncoderModel,
     HashingTF,
     IDF,
     NGram,
@@ -536,6 +538,61 @@ def test_word2vec(self):
             model2 = Word2VecModel.load(d)
             self.assertEqual(str(model), str(model2))

+    def test_count_vectorizer(self):
+        df = self.spark.createDataFrame(
+            [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])],
+            ["label", "raw"],
+        )
+
+        cv = CountVectorizer()
+        cv.setInputCol("raw")
+        cv.setOutputCol("vectors")
+        self.assertEqual(cv.getInputCol(), "raw")
+        self.assertEqual(cv.getOutputCol(), "vectors")
+
+        model = cv.fit(df)
+        self.assertEqual(sorted(model.vocabulary), ["a", "b", "c"])
+
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["label", "raw", "vectors"])
+        self.assertEqual(output.count(), 2)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="count_vectorizer") as d:
+            cv.write().overwrite().save(d)
+            cv2 = CountVectorizer.load(d)
+            self.assertEqual(str(cv), str(cv2))
+
+            model.write().overwrite().save(d)
+            model2 = CountVectorizerModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
+    def test_one_hot_encoder(self):
+        df = self.spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
+
+        encoder = OneHotEncoder()
+        encoder.setInputCols(["input"])
+        encoder.setOutputCols(["output"])
+        self.assertEqual(encoder.getInputCols(), ["input"])
+        self.assertEqual(encoder.getOutputCols(), ["output"])
+
+        model = encoder.fit(df)
+        self.assertEqual(model.categorySizes, [3])
+
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["input", "output"])
+        self.assertEqual(output.count(), 3)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="count_vectorizer") as d:
+            encoder.write().overwrite().save(d)
+            encoder2 = OneHotEncoder.load(d)
+            self.assertEqual(str(encoder), str(encoder2))
+
+            model.write().overwrite().save(d)
+            model2 = OneHotEncoderModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
     def test_tokenizer(self):
         df = self.spark.createDataFrame([("a b c",)], ["text"])
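
The new tests live in the classic test_feature.py suite. In Spark, Connect coverage for such tests typically comes from a parity suite that reuses the same test methods against a Connect-backed session; the sketch below shows that pattern, with the mixin and base-class names being assumptions about the surrounding test infrastructure rather than part of this diff.

# Hedged sketch of the parity-test pattern: reuse the feature tests under Spark Connect.
# FeatureTestsMixin and ReusedConnectTestCase are assumed names, not quoted from this commit.
import unittest

from pyspark.ml.tests.test_feature import FeatureTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
    pass


if __name__ == "__main__":
    unittest.main()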

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 3 additions & 1 deletion
@@ -589,7 +589,9 @@ private[ml] object MLUtils {
     (classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
     (classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
     (classOf[PCAModel], Set("pc", "explainedVariance")),
-    (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")))
+    (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")),
+    (classOf[CountVectorizerModel], Set("vocabulary")),
+    (classOf[OneHotEncoderModel], Set("categorySizes")))

   private def validate(obj: Any, method: String): Unit = {
     assert(obj != null)
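
This allowlist governs which attributes and methods a Connect client may invoke on a server-side model object, so the two new entries are what let the corresponding Python properties resolve remotely. A hedged, self-contained sketch; the remote URL and toy data are placeholders.

# Hedged sketch: reading the model attributes gated by the new allowlist entries.
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizer, OneHotEncoder

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # placeholder remote URL

cv_model = CountVectorizer(inputCol="words", outputCol="vec").fit(
    spark.createDataFrame([(["a", "b", "b"],)], ["words"]))
print(cv_model.vocabulary)      # fetched via the (CountVectorizerModel, Set("vocabulary")) entry

ohe_model = OneHotEncoder(inputCols=["x"], outputCols=["y"]).fit(
    spark.createDataFrame([(0.0,), (1.0,), (1.0,)], ["x"]))
print(ohe_model.categorySizes)  # fetched via the (OneHotEncoderModel, Set("categorySizes")) entry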

0 commit comments
