Skip to content

Commit d36edda

Browse files
committed
[SPARK-51105][ML][PYTHON][CONNECT][TESTS] Add parity test for ml functions
### What changes were proposed in this pull request?

Add parity test for ml functions.

### Why are the changes needed?

For test coverage.

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49824 from zhengruifeng/ml_connect_f_ut.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 9dbbb5b commit d36edda

File tree

3 files changed

+102
-8
lines changed

3 files changed

+102
-8
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,7 @@ def __hash__(self):
11271127
"pyspark.ml.tests.connect.test_parity_clustering",
11281128
"pyspark.ml.tests.connect.test_parity_evaluation",
11291129
"pyspark.ml.tests.connect.test_parity_feature",
1130+
"pyspark.ml.tests.connect.test_parity_functions",
11301131
"pyspark.ml.tests.connect.test_parity_pipeline",
11311132
"pyspark.ml.tests.connect.test_parity_tuning",
11321133
"pyspark.ml.tests.connect.test_parity_ovr",
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import unittest
19+
20+
from pyspark.ml.tests.test_functions import (
21+
ArrayVectorConversionTestsMixin,
22+
PredictBatchUDFTestsMixin,
23+
)
24+
from pyspark.testing.connectutils import ReusedConnectTestCase
25+
from pyspark.testing.sqlutils import (
26+
have_pandas,
27+
have_pyarrow,
28+
pandas_requirement_message,
29+
pyarrow_requirement_message,
30+
)
31+
32+
33+
class ArrayVectorConversionParityTests(ArrayVectorConversionTestsMixin, ReusedConnectTestCase):
    """Run the array<->vector conversion tests from the mixin on a Spark Connect session."""
35+
36+
37+
@unittest.skipIf(
    not (have_pandas and have_pyarrow),
    pandas_requirement_message or pyarrow_requirement_message,
)
class PredictBatchUDFParityTests(PredictBatchUDFTestsMixin, ReusedConnectTestCase):
    """Run the predict_batch_udf tests from the mixin on a Spark Connect session.

    Skipped entirely when pandas or pyarrow is unavailable, since
    predict_batch_udf is an Arrow-backed pandas UDF under the hood.
    """
43+
44+
45+
if __name__ == "__main__":
    # Re-export the test classes at module level so unittest discovery finds them.
    from pyspark.ml.tests.connect.test_parity_functions import *  # noqa: F401

    # Prefer XML report output (used by CI) when xmlrunner is installed;
    # otherwise fall back to the default text runner.
    try:
        import xmlrunner  # type: ignore[import]
    except ImportError:
        runner = None
    else:
        runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)

    unittest.main(testRunner=runner, verbosity=2)

python/pyspark/ml/tests/test_functions.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,58 @@
1919
import numpy as np
2020

2121
from pyspark.loose_version import LooseVersion
22-
from pyspark.ml.functions import predict_batch_udf
22+
from pyspark.ml.linalg import DenseVector
23+
from pyspark.ml.functions import array_to_vector, vector_to_array, predict_batch_udf
2324
from pyspark.sql.functions import array, struct, col
2425
from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, StructField, FloatType
25-
from pyspark.testing.mlutils import SparkSessionTestCase
2626
from pyspark.testing.sqlutils import (
2727
have_pandas,
2828
have_pyarrow,
2929
pandas_requirement_message,
3030
pyarrow_requirement_message,
31+
ReusedSQLTestCase,
3132
)
3233

3334

34-
@unittest.skipIf(
35-
not have_pandas or not have_pyarrow,
36-
pandas_requirement_message or pyarrow_requirement_message,
37-
)
38-
class PredictBatchUDFTests(SparkSessionTestCase):
35+
class ArrayVectorConversionTestsMixin:
    """Shared tests for pyspark.ml.functions.array_to_vector / vector_to_array.

    Written as a mixin so the same assertions can run against both a classic
    SQL session (ArrayVectorConversionTests) and a Spark Connect session
    (a parity test class combines this mixin with a Connect test base).
    Assumes the host class provides ``self.spark`` — TODO confirm against
    the concrete test bases.
    """

    def test_array_vector_conversion(self):
        """Round-trip array -> DenseVector -> array and check values survive."""
        spark = self.spark

        query = """
            SELECT * FROM VALUES
            (1, ARRAY(1.0, 2.0, 3.0)),
            (1, ARRAY(-1.0, -2.0, -3.0))
            AS tab(a, b)
            """

        df = spark.sql(query)

        # array column -> ml DenseVector column
        df1 = df.select("*", array_to_vector(df.b).alias("c"))
        self.assertEqual(df1.columns, ["a", "b", "c"])
        self.assertEqual(df1.count(), 2)
        self.assertEqual(
            [r.c for r in df1.select("c").collect()],
            [DenseVector([1.0, 2.0, 3.0]), DenseVector([-1.0, -2.0, -3.0])],
        )

        # vector column back to a plain double-array column
        df2 = df1.select("*", vector_to_array(df1.c).alias("d"))
        self.assertEqual(df2.columns, ["a", "b", "c", "d"])
        self.assertEqual(df2.count(), 2)
        self.assertEqual(
            [r.d for r in df2.select("d").collect()],
            [[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]],
        )
63+
64+
65+
class ArrayVectorConversionTests(ArrayVectorConversionTestsMixin, ReusedSQLTestCase):
    """Run the array<->vector conversion tests from the mixin on a classic SQL session."""
67+
68+
69+
class PredictBatchUDFTestsMixin:
3970
def setUp(self):
4071
import pandas as pd
4172

42-
super(PredictBatchUDFTests, self).setUp()
73+
super(PredictBatchUDFTestsMixin, self).setUp()
4374
self.data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)
4475

4576
# 4 scalar columns
@@ -533,6 +564,14 @@ def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
533564
self.assertEqual(value, 9.0)
534565

535566

567+
@unittest.skipIf(
    not (have_pandas and have_pyarrow),
    pandas_requirement_message or pyarrow_requirement_message,
)
class PredictBatchUDFTests(PredictBatchUDFTestsMixin, ReusedSQLTestCase):
    """Run the predict_batch_udf tests from the mixin on a classic SQL session.

    Skipped entirely when pandas or pyarrow is unavailable, since
    predict_batch_udf is an Arrow-backed pandas UDF under the hood.
    """
573+
574+
536575
if __name__ == "__main__":
537576
from pyspark.ml.tests.test_functions import * # noqa: F401
538577

0 commit comments

Comments
 (0)