Skip to content

Commit c685b5f

Browse files
aai95 authored and dbtsai committed
[SPARK-24411][SQL] Adding native Java tests for 'isInCollection'
## What changes were proposed in this pull request?

`JavaColumnExpressionSuite.java` was added and `org.apache.spark.sql.ColumnExpressionSuite#test("isInCollection: Java Collection")` was removed. It provides native Java tests for the method `org.apache.spark.sql.Column#isInCollection`.

Closes apache#22253 from aai95/isInCollectionJavaTest.

Authored-by: aai95 <[email protected]>
Signed-off-by: DB Tsai <[email protected]>
1 parent 135ff16 commit c685b5f

File tree

2 files changed

+95
-21
lines changed

2 files changed

+95
-21
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package test.org.apache.spark.sql;
19+
20+
import org.apache.spark.api.java.function.FilterFunction;
21+
import org.apache.spark.sql.Column;
22+
import org.apache.spark.sql.Dataset;
23+
import org.apache.spark.sql.Row;
24+
import org.apache.spark.sql.RowFactory;
25+
import org.apache.spark.sql.test.TestSparkSession;
26+
import org.apache.spark.sql.types.StructType;
27+
import org.junit.After;
28+
import org.junit.Assert;
29+
import org.junit.Before;
30+
import org.junit.Test;
31+
32+
import java.util.*;
33+
34+
import static org.apache.spark.sql.types.DataTypes.*;
35+
36+
public class JavaColumnExpressionSuite {
37+
private transient TestSparkSession spark;
38+
39+
@Before
40+
public void setUp() {
41+
spark = new TestSparkSession();
42+
}
43+
44+
@After
45+
public void tearDown() {
46+
spark.stop();
47+
spark = null;
48+
}
49+
50+
@Test
51+
public void isInCollectionWorksCorrectlyOnJava() {
52+
List<Row> rows = Arrays.asList(
53+
RowFactory.create(1, "x"),
54+
RowFactory.create(2, "y"),
55+
RowFactory.create(3, "z"));
56+
StructType schema = createStructType(Arrays.asList(
57+
createStructField("a", IntegerType, false),
58+
createStructField("b", StringType, false)));
59+
Dataset<Row> df = spark.createDataFrame(rows, schema);
60+
// Test with different types of collections
61+
Assert.assertTrue(Arrays.equals(
62+
(Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(),
63+
(Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
64+
));
65+
Assert.assertTrue(Arrays.equals(
66+
(Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(),
67+
(Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect()
68+
));
69+
Assert.assertTrue(Arrays.equals(
70+
(Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(),
71+
(Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect()
72+
));
73+
}
74+
75+
@Test
76+
public void isInCollectionCheckExceptionMessage() {
77+
List<Row> rows = Arrays.asList(
78+
RowFactory.create(1, Arrays.asList(1)),
79+
RowFactory.create(2, Arrays.asList(2)),
80+
RowFactory.create(3, Arrays.asList(3)));
81+
StructType schema = createStructType(Arrays.asList(
82+
createStructField("a", IntegerType, false),
83+
createStructField("b", createArrayType(IntegerType, false), false)));
84+
Dataset<Row> df = spark.createDataFrame(rows, schema);
85+
try {
86+
df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))));
87+
Assert.fail("Expected org.apache.spark.sql.AnalysisException");
88+
} catch (Exception e) {
89+
Arrays.asList("cannot resolve",
90+
"due to data type mismatch: Arguments must be same type but were")
91+
.forEach(s -> Assert.assertTrue(
92+
e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))));
93+
}
94+
}
95+
}

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -436,27 +436,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
436436
}
437437
}
438438

439-
test("isInCollection: Java Collection") {
440-
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
441-
// Test with different types of collections
442-
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)),
443-
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
444-
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)),
445-
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
446-
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)),
447-
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
448-
449-
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
450-
451-
val e = intercept[AnalysisException] {
452-
df2.filter($"a".isInCollection(Seq($"b").asJava))
453-
}
454-
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
455-
.foreach { s =>
456-
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
457-
}
458-
}
459-
460439
test("&&") {
461440
checkAnswer(
462441
booleanData.filter($"a" && true),

0 commit comments

Comments
 (0)