|
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| 17 | + |
| 18 | +package test.org.apache.spark.sql; |
| 19 | + |
| 20 | +import org.apache.spark.api.java.function.FilterFunction; |
| 21 | +import org.apache.spark.sql.Column; |
| 22 | +import org.apache.spark.sql.Dataset; |
| 23 | +import org.apache.spark.sql.Row; |
| 24 | +import org.apache.spark.sql.RowFactory; |
| 25 | +import org.apache.spark.sql.test.TestSparkSession; |
| 26 | +import org.apache.spark.sql.types.StructType; |
| 27 | +import org.junit.After; |
| 28 | +import org.junit.Assert; |
| 29 | +import org.junit.Before; |
| 30 | +import org.junit.Test; |
| 31 | + |
| 32 | +import java.util.*; |
| 33 | + |
| 34 | +import static org.apache.spark.sql.types.DataTypes.*; |
| 35 | + |
| 36 | +public class JavaColumnExpressionSuite { |
| 37 | + private transient TestSparkSession spark; |
| 38 | + |
| 39 | + @Before |
| 40 | + public void setUp() { |
| 41 | + spark = new TestSparkSession(); |
| 42 | + } |
| 43 | + |
| 44 | + @After |
| 45 | + public void tearDown() { |
| 46 | + spark.stop(); |
| 47 | + spark = null; |
| 48 | + } |
| 49 | + |
| 50 | + @Test |
| 51 | + public void isInCollectionWorksCorrectlyOnJava() { |
| 52 | + List<Row> rows = Arrays.asList( |
| 53 | + RowFactory.create(1, "x"), |
| 54 | + RowFactory.create(2, "y"), |
| 55 | + RowFactory.create(3, "z")); |
| 56 | + StructType schema = createStructType(Arrays.asList( |
| 57 | + createStructField("a", IntegerType, false), |
| 58 | + createStructField("b", StringType, false))); |
| 59 | + Dataset<Row> df = spark.createDataFrame(rows, schema); |
| 60 | + // Test with different types of collections |
| 61 | + Assert.assertTrue(Arrays.equals( |
| 62 | + (Row[]) df.filter(df.col("a").isInCollection(Arrays.asList(1, 2))).collect(), |
| 63 | + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() |
| 64 | + )); |
| 65 | + Assert.assertTrue(Arrays.equals( |
| 66 | + (Row[]) df.filter(df.col("a").isInCollection(new HashSet<>(Arrays.asList(1, 2)))).collect(), |
| 67 | + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 1 || r.getInt(0) == 2).collect() |
| 68 | + )); |
| 69 | + Assert.assertTrue(Arrays.equals( |
| 70 | + (Row[]) df.filter(df.col("a").isInCollection(new ArrayList<>(Arrays.asList(3, 1)))).collect(), |
| 71 | + (Row[]) df.filter((FilterFunction<Row>) r -> r.getInt(0) == 3 || r.getInt(0) == 1).collect() |
| 72 | + )); |
| 73 | + } |
| 74 | + |
| 75 | + @Test |
| 76 | + public void isInCollectionCheckExceptionMessage() { |
| 77 | + List<Row> rows = Arrays.asList( |
| 78 | + RowFactory.create(1, Arrays.asList(1)), |
| 79 | + RowFactory.create(2, Arrays.asList(2)), |
| 80 | + RowFactory.create(3, Arrays.asList(3))); |
| 81 | + StructType schema = createStructType(Arrays.asList( |
| 82 | + createStructField("a", IntegerType, false), |
| 83 | + createStructField("b", createArrayType(IntegerType, false), false))); |
| 84 | + Dataset<Row> df = spark.createDataFrame(rows, schema); |
| 85 | + try { |
| 86 | + df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))); |
| 87 | + Assert.fail("Expected org.apache.spark.sql.AnalysisException"); |
| 88 | + } catch (Exception e) { |
| 89 | + Arrays.asList("cannot resolve", |
| 90 | + "due to data type mismatch: Arguments must be same type but were") |
| 91 | + .forEach(s -> Assert.assertTrue( |
| 92 | + e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))); |
| 93 | + } |
| 94 | + } |
| 95 | +} |
0 commit comments