Skip to content

Commit 4c00d63

Browse files
wbo4958 authored and zhengruifeng committed
[SPARK-51004][ML][PYTHON][CONNECT] Add support for IndexToString
### What changes were proposed in this pull request? This PR adds support for IndexToString and adds labels/labelsArray to ALLOWED_LIST. ### Why are the changes needed? New feature parity and bug fix. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? CI passes. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49690 from wbo4958/index-str. Authored-by: Bobby Wang <wbo4958@gmail.com> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org> (cherry picked from commit b5deb8d) Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 9da7a26 commit 4c00d63

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ org.apache.spark.ml.feature.SQLTransformer
3232
org.apache.spark.ml.feature.StopWordsRemover
3333
org.apache.spark.ml.feature.FeatureHasher
3434
org.apache.spark.ml.feature.HashingTF
35+
org.apache.spark.ml.feature.IndexToString
3536

3637
########### Model for loading
3738
# classification

python/pyspark/ml/tests/test_feature.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
BucketedRandomProjectionLSHModel,
7373
MinHashLSH,
7474
MinHashLSHModel,
75+
IndexToString,
7576
)
7677
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
7778
from pyspark.sql import Row
@@ -80,6 +81,51 @@
8081

8182

8283
class FeatureTestsMixin:
84+
def test_index_string(self):
85+
dataset = self.spark.createDataFrame(
86+
[
87+
(0, "a"),
88+
(1, "b"),
89+
(2, "c"),
90+
(3, "a"),
91+
(4, "a"),
92+
(5, "c"),
93+
],
94+
["id", "label"],
95+
)
96+
97+
indexer = StringIndexer(inputCol="label", outputCol="labelIndex").fit(dataset)
98+
transformed = indexer.transform(dataset)
99+
idx2str = (
100+
IndexToString()
101+
.setInputCol("labelIndex")
102+
.setOutputCol("sameLabel")
103+
.setLabels(indexer.labels)
104+
)
105+
106+
def check(t: IndexToString) -> None:
107+
self.assertEqual(t.getInputCol(), "labelIndex")
108+
self.assertEqual(t.getOutputCol(), "sameLabel")
109+
self.assertEqual(t.getLabels(), indexer.labels)
110+
111+
check(idx2str)
112+
113+
ret = idx2str.transform(transformed)
114+
self.assertEqual(
115+
sorted(ret.schema.names), sorted(["id", "label", "labelIndex", "sameLabel"])
116+
)
117+
118+
rows = ret.select("label", "sameLabel").collect()
119+
for r in rows:
120+
self.assertEqual(r.label, r.sameLabel)
121+
122+
# save & load
123+
with tempfile.TemporaryDirectory(prefix="index_string") as d:
124+
idx2str.write().overwrite().save(d)
125+
idx2str2 = IndexToString.load(d)
126+
self.assertEqual(str(idx2str), str(idx2str2))
127+
check(idx2str2)
128+
83129
def test_dct(self):
84130
df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"])
85131
dct = DCT()
@@ -128,6 +174,7 @@ def test_string_indexer(self):
128174
si = StringIndexer(inputCol="label1", outputCol="index1")
129175
model = si.fit(df.select("label1"))
130176
self.assertEqual(si.uid, model.uid)
177+
self.assertEqual(model.labels, list(model.labelsArray[0]))
131178

132179
# read/write
133180
with tempfile.TemporaryDirectory(prefix="string_indexer") as tmp_dir:
@@ -188,6 +235,7 @@ def test_pca(self):
188235
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
189236

190237
model = pca.fit(df)
238+
self.assertTrue(np.allclose(model.pc.toArray()[0], [-0.44859172, -0.28423808], atol=1e-4))
191239
self.assertEqual(pca.uid, model.uid)
192240
self.assertEqual(model.getK(), 2)
193241
self.assertTrue(

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ private[ml] object MLUtils {
646646
(classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")),
647647
(classOf[CountVectorizerModel], Set("vocabulary")),
648648
(classOf[OneHotEncoderModel], Set("categorySizes")),
649+
(classOf[StringIndexerModel], Set("labels", "labelsArray")),
649650
(classOf[IDFModel], Set("idf", "docFreq", "numDocs")))
650651

651652
private def validate(obj: Any, method: String): Unit = {

0 commit comments

Comments (0)