Skip to content

Commit acdda8a

Browse files
committed
[SPARK-50989][ML][PYTHON][CONNECT] Support NGram, Normalizer and Interaction on connect
### What changes were proposed in this pull request?
Support NGram, Normalizer and Interaction on connect

### Why are the changes needed?
feature parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49668 from zhengruifeng/ml_connect_ngram.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 4af6be6 commit acdda8a

File tree

3 files changed

+63
-11
lines changed

3 files changed

+63
-11
lines changed

mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
########### Transformers
2222
org.apache.spark.ml.feature.DCT
23+
org.apache.spark.ml.feature.NGram
24+
org.apache.spark.ml.feature.Normalizer
25+
org.apache.spark.ml.feature.Interaction
2326
org.apache.spark.ml.feature.Binarizer
2427
org.apache.spark.ml.feature.Bucketizer
2528
org.apache.spark.ml.feature.VectorAssembler

python/pyspark/ml/tests/connect/test_parity_feature.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@
2222

2323

2424
class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
25-
@unittest.skip("Need to support.")
26-
def test_ngram(self):
27-
super().test_ngram()
28-
2925
@unittest.skip("Need to support.")
3026
def test_count_vectorizer_with_binary(self):
3127
super().test_count_vectorizer_with_binary()

python/pyspark/ml/tests/test_feature.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
Imputer,
3838
ImputerModel,
3939
NGram,
40+
Normalizer,
41+
Interaction,
4042
RFormula,
4143
Tokenizer,
4244
SQLTransformer,
@@ -923,13 +925,64 @@ def test_idf(self):
923925
self.assertEqual(str(model), str(model2))
924926

925927
def test_ngram(self):
    """Exercise NGram: param getters, transform output, and save/load round-trip."""
    spark = self.spark
    df = spark.createDataFrame([Row(input=["a", "b", "c", "d", "e"])])

    ngram = NGram(n=4, inputCol="input", outputCol="output")

    # Param getters reflect the constructor arguments.
    self.assertEqual(ngram.getN(), 4)
    self.assertEqual(ngram.getInputCol(), "input")
    self.assertEqual(ngram.getOutputCol(), "output")

    # Five tokens with n=4 yield exactly two 4-grams.
    output = ngram.transform(df)
    self.assertEqual(output.head().output, ["a b c d", "b c d e"])

    # save & load: the reloaded transformer must stringify identically.
    with tempfile.TemporaryDirectory(prefix="ngram") as d:
        ngram.write().overwrite().save(d)
        ngram2 = NGram.load(d)
        self.assertEqual(str(ngram), str(ngram2))
945+
def test_normalizer(self):
    """Exercise Normalizer: param getters, transform schema/count, and save/load round-trip."""
    spark = self.spark
    df = spark.createDataFrame(
        [(Vectors.dense([3.0, -4.0]),), (Vectors.sparse(4, {1: 4.0, 3: 3.0}),)],
        ["input"],
    )

    normalizer = Normalizer(p=2.0, inputCol="input", outputCol="output")

    # Param getters reflect the constructor arguments.
    self.assertEqual(normalizer.getP(), 2.0)
    self.assertEqual(normalizer.getInputCol(), "input")
    self.assertEqual(normalizer.getOutputCol(), "output")

    # Transform appends the output column and preserves the row count.
    output = normalizer.transform(df)
    self.assertEqual(output.columns, ["input", "output"])
    self.assertEqual(output.count(), 2)

    # save & load: the reloaded transformer must stringify identically.
    with tempfile.TemporaryDirectory(prefix="normalizer") as d:
        normalizer.write().overwrite().save(d)
        normalizer2 = Normalizer.load(d)
        self.assertEqual(str(normalizer), str(normalizer2))
967+
def test_interaction(self):
    """Exercise Interaction: setter/getter symmetry, transform schema/count, and save/load."""
    spark = self.spark
    df = spark.createDataFrame([(0.0, 1.0), (2.0, 3.0)], ["a", "b"])

    # Interaction has no keyword constructor params here; configure via setters.
    interaction = Interaction()
    interaction.setInputCols(["a", "b"])
    interaction.setOutputCol("ab")
    self.assertEqual(interaction.getInputCols(), ["a", "b"])
    self.assertEqual(interaction.getOutputCol(), "ab")

    # Transform appends the interaction column and preserves the row count.
    output = interaction.transform(df)
    self.assertEqual(output.columns, ["a", "b", "ab"])
    self.assertEqual(output.count(), 2)

    # save & load: the reloaded transformer must stringify identically.
    with tempfile.TemporaryDirectory(prefix="interaction") as d:
        interaction.write().overwrite().save(d)
        interaction2 = Interaction.load(d)
        self.assertEqual(str(interaction), str(interaction2))
933986

934987
def test_count_vectorizer_with_binary(self):
935988
dataset = self.spark.createDataFrame(

0 commit comments

Comments
 (0)