diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionEvaluator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionEvaluator.java new file mode 100644 index 0000000000000..0bacc601931f0 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionEvaluator.java @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.util; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import scala.jdk.javaapi.CollectionConverters; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.InterpretedPredicate; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Utility class for evaluating an {@link InternalRow} against a data source V2 {@link Predicate}. + * + *

This class provides methods to translate DSV2 predicates into Catalyst expressions based on a + * given schema, and to evaluate these predicates against InternalRows. + * + * @since 4.1.0 + */ +public final class V2ExpressionEvaluator { + + /** + * Converts a Spark DataSourceV2 {@link Predicate} to a Catalyst {@link Expression}. + * + *

This method translates supported DSV2 predicates into their equivalent Catalyst expressions, + * using the provided schema for column resolution. Unsupported predicates, or those referencing + * unknown columns, will result in an empty Optional. + * + *

Supported predicates include: + * + *

+ * + * @param predicate the DSV2 Predicate to convert + * @param schema the schema used for resolving column references + * @return Catalyst Expression representing the converted predicate, or empty if the predicate is + * unsupported or references unknown columns + */ + public static Optional dsv2PredicateToCatalystExpression( + org.apache.spark.sql.connector.expressions.filter.Predicate predicate, StructType schema) { + String predicateName = predicate.name(); + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + + switch (predicateName) { + case "IS_NULL": + if (children.length == 1) { + Optional expressionOpt = + dsv2ExpressionToCatalystExpression(children[0], schema); + if (expressionOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.IsNull(expressionOpt.get())); + } + } + break; + + case "IS_NOT_NULL": + if (children.length == 1) { + Optional expressionOpt = + dsv2ExpressionToCatalystExpression(children[0], schema); + if (expressionOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.IsNotNull(expressionOpt.get())); + } + } + break; + + case "STARTS_WITH": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.StartsWith( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "ENDS_WITH": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.EndsWith( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "CONTAINS": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.Contains( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "IN": + if (children.length >= 2) { + Optional firstOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + if (firstOpt.isPresent()) { + List values = new ArrayList<>(); + for (int i = 1; i < children.length; i++) { + Optional valueOpt = + dsv2ExpressionToCatalystExpression(children[i], schema); + if (valueOpt.isPresent()) { + values.add(valueOpt.get()); + } else { + // if any value in the IN list cannot be converted, return empty + return Optional.empty(); + } + } + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.In( + firstOpt.get(), CollectionConverters.asScala(values).toSeq() + )); + } + } + break; + + case "=": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.EqualTo( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "<>": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.Not( + new org.apache.spark.sql.catalyst.expressions.EqualTo( + leftOpt.get(), rightOpt.get()))); + } + } + break; + + case "<=>": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.EqualNullSafe( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "<": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.LessThan( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "<=": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.LessThanOrEqual( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case ">": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.GreaterThan( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case ">=": + if (children.length == 2) { + Optional leftOpt = dsv2ExpressionToCatalystExpression(children[0], schema); + Optional rightOpt = dsv2ExpressionToCatalystExpression(children[1], schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual( + leftOpt.get(), rightOpt.get())); + } + } + break; + + case "AND": + if (children.length == 2) { + Optional leftOpt = + dsv2PredicateToCatalystExpression( + (org.apache.spark.sql.connector.expressions.filter.Predicate) + predicate.children()[0], + schema); + Optional rightOpt = + dsv2PredicateToCatalystExpression( + (org.apache.spark.sql.connector.expressions.filter.Predicate) + predicate.children()[1], + schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.And(leftOpt.get(), rightOpt.get())); + } + } + break; + + case "OR": + if (children.length == 2) { + Optional leftOpt = + dsv2PredicateToCatalystExpression( + (org.apache.spark.sql.connector.expressions.filter.Predicate) + predicate.children()[0], + schema); + Optional rightOpt = + dsv2PredicateToCatalystExpression( + (org.apache.spark.sql.connector.expressions.filter.Predicate) + predicate.children()[1], + schema); + if (leftOpt.isPresent() && rightOpt.isPresent()) { + return Optional.of( + new org.apache.spark.sql.catalyst.expressions.Or(leftOpt.get(), rightOpt.get())); + } + } + break; + + case "NOT": + if (children.length == 1) { + Optional childOpt = + dsv2PredicateToCatalystExpression( + (org.apache.spark.sql.connector.expressions.filter.Predicate) + predicate.children()[0], + schema); + if (childOpt.isPresent()) { + return Optional.of(new org.apache.spark.sql.catalyst.expressions.Not(childOpt.get())); + } + } + break; + + case "ALWAYS_TRUE": + if (children.length == 0) { + return Optional.of( + org.apache.spark.sql.catalyst.expressions.Literal.create( + true, org.apache.spark.sql.types.DataTypes.BooleanType)); + } + break; + + case "ALWAYS_FALSE": + if (children.length == 0) { + return Optional.of( + org.apache.spark.sql.catalyst.expressions.Literal.create( + false, org.apache.spark.sql.types.DataTypes.BooleanType)); + } + break; + } + + return Optional.empty(); + } + + /** + * Translate a DSV2 Expression to a Catalyst {@link Expression} using the provided schema. + * + *

This method handles NamedReference and LiteralValue expressions. NamedReferences are + * resolved to BoundReferences based on the schema, while LiteralValues are converted to Catalyst + * Literals. Unsupported expression types or references to unknown columns will result in an empty + * Optional. + * + * @param expr the DSV2 Expression to resolve + * @param schema the schema used for resolving column references + * @return Catalyst Expression representing the resolved expression, or empty if the expression is + * unsupported or references unknown columns + */ + public static Optional dsv2ExpressionToCatalystExpression( + org.apache.spark.sql.connector.expressions.Expression expr, StructType schema) { + if (expr instanceof NamedReference ref) { + String columnName = ref.fieldNames()[0]; + try { + int index = schema.fieldIndex(columnName); + StructField field = schema.fields()[index]; + return Optional.of(new BoundReference(index, field.dataType(), field.nullable())); + } catch (IllegalArgumentException e) { + // schema.fieldIndex(columnName) throws IllegalArgumentException if a field with the given + // name does not exist + return Optional.empty(); + } + } else if (expr instanceof LiteralValue literal) { + return Optional.of( + org.apache.spark.sql.catalyst.expressions.Literal.create( + literal.value(), literal.dataType())); + } else { + return Optional.empty(); + } + } + + + /** + * Evaluates a DSV2 {@link Predicate} on an {@link InternalRow} of the provided schema. + * + *

This method first converts the DSV2 Predicate to a Catalyst Expression using the provided + * schema. If the conversion is successful, it creates a Predicate evaluator and evaluates it + * against the given InternalRow. If the predicate cannot be converted, an empty Optional is + * returned. + * + * @param predicate the DSV2 Predicate to evaluate + * @param internalRow the InternalRow to evaluate the predicate against + * @param schema the schema used for resolving column references in the predicate + * @return Optional containing the result of the evaluation (true or false), or empty if the + * predicate could not be converted + */ + public static Optional evaluateInternalRowOnDsv2Predicate( + org.apache.spark.sql.connector.expressions.filter.Predicate predicate, + InternalRow internalRow, + StructType schema) { + Optional catalystExpr = dsv2PredicateToCatalystExpression(predicate, schema); + if (catalystExpr.isEmpty()) { + return Optional.empty(); + } + InterpretedPredicate evaluator = + org.apache.spark.sql.catalyst.expressions.Predicate.createInterpreted(catalystExpr.get()); + return Optional.of(evaluator.eval(internalRow)); + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/util/V2ExpressionEvaluatorSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/util/V2ExpressionEvaluatorSuite.java new file mode 100644 index 0000000000000..d87ffcce9643e --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/util/V2ExpressionEvaluatorSuite.java @@ -0,0 +1,133 @@ +package org.apache.spark.sql.connector.util; + +import java.util.Optional; + +import org.junit.jupiter.api.Test; +import static org.apache.spark.sql.connector.util.V2ExpressionEvaluator.dsv2PredicateToCatalystExpression; +import static org.apache.spark.sql.connector.util.V2ExpressionEvaluator.evaluateInternalRowOnDsv2Predicate; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.NamedReference; + +public class V2ExpressionEvaluatorSuite { + + private final org.apache.spark.sql.types.StructType testSchema = + new org.apache.spark.sql.types.StructType() + .add("id", org.apache.spark.sql.types.DataTypes.IntegerType, false) + .add("name", org.apache.spark.sql.types.DataTypes.StringType, true) + .add("age", org.apache.spark.sql.types.DataTypes.IntegerType, true); + + private final InternalRow[] testData = new InternalRow[]{ + new GenericInternalRow(new Object[]{1, org.apache.spark.unsafe.types.UTF8String.fromString("Alice"), 30}), + new GenericInternalRow(new Object[]{2, org.apache.spark.unsafe.types.UTF8String.fromString("Bob"), null}), + new GenericInternalRow(new Object[]{3, null, 25}), + new GenericInternalRow(new Object[]{4, org.apache.spark.unsafe.types.UTF8String.fromString("David"), 35}) + }; + + private final NamedReference idRef = FieldReference.apply("id"); + private final NamedReference nameRef = FieldReference.apply("name"); + private final NamedReference ageRef = FieldReference.apply("age"); + + private final org.apache.spark.sql.connector.expressions.filter.Predicate isNullPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "IS_NULL", new org.apache.spark.sql.connector.expressions.Expression[]{nameRef}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate isNotNullPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "IS_NOT_NULL", new org.apache.spark.sql.connector.expressions.Expression[]{nameRef}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate inPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "IN", + new org.apache.spark.sql.connector.expressions.Expression[]{ + idRef, + LiteralValue.apply(1, org.apache.spark.sql.types.DataTypes.IntegerType), + LiteralValue.apply(3, org.apache.spark.sql.types.DataTypes.IntegerType)}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate equalsPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "=", + new org.apache.spark.sql.connector.expressions.Expression[]{ + idRef, + LiteralValue.apply(2, org.apache.spark.sql.types.DataTypes.IntegerType)}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate greaterThanPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + ">", + new org.apache.spark.sql.connector.expressions.Expression[]{ + ageRef, + LiteralValue.apply(20, org.apache.spark.sql.types.DataTypes.IntegerType)}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate notPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "NOT", + new org.apache.spark.sql.connector.expressions.Expression[]{isNullPredicate}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate andPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "AND", + new org.apache.spark.sql.connector.expressions.Expression[]{isNotNullPredicate, greaterThanPredicate}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate orPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "OR", + new org.apache.spark.sql.connector.expressions.Expression[]{isNullPredicate, equalsPredicate}); + private final org.apache.spark.sql.connector.expressions.filter.Predicate unsupportedPredicate = + new org.apache.spark.sql.connector.expressions.filter.Predicate( + "UNSUPPORTED_OP", + new org.apache.spark.sql.connector.expressions.Expression[]{idRef}); + + @Test + public void testDsv2PredicateToCatalystExpression() { + // Null tests + checkExpressionConversionAndEvaluation(isNullPredicate, true, new Boolean[]{false, false, true, false}); + checkExpressionConversionAndEvaluation(isNotNullPredicate, true, new Boolean[]{true, true, false, true}); + + // IN operator + checkExpressionConversionAndEvaluation(inPredicate, true, new Boolean[]{true, false, true, false}); + + // Comparison operators + checkExpressionConversionAndEvaluation(equalsPredicate, true, new Boolean[]{false, true, false, false}); + checkExpressionConversionAndEvaluation(greaterThanPredicate, true, new Boolean[]{true, false, true, true}); + + // Logical operators + checkExpressionConversionAndEvaluation(notPredicate, true, new Boolean[]{true, true, false, true}); + // AND: name IS NOT NULL AND age > 20 + // Row 0: Alice, 30 -> true AND true = true + // Row 1: Bob, null -> true AND false = false + // Row 2: null, 25 -> false AND true = false + // Row 3: David, 35 -> true AND true = true + checkExpressionConversionAndEvaluation(andPredicate, true, new Boolean[]{true, false, false, true}); + // OR: name IS NULL OR id = 2 + // Row 0: false OR false = false + // Row 1: false OR true = true + // Row 2: true OR false = true + // Row 3: false OR false = false + checkExpressionConversionAndEvaluation(orPredicate, true, new Boolean[]{false, true, true, false}); + + // Unsupported predicate + checkExpressionConversionAndEvaluation(unsupportedPredicate, false, null); + } + + private void checkExpressionConversionAndEvaluation( + org.apache.spark.sql.connector.expressions.filter.Predicate predicate, + boolean isConvertible, + Boolean[] expectedResults) { + Optional catalystExpr = dsv2PredicateToCatalystExpression(predicate, testSchema); + + if (isConvertible) { + assertTrue(catalystExpr.isPresent(), "Predicate should be convertible"); + for (int i = 0; i < testData.length; i++) { + Optional evalResult = evaluateInternalRowOnDsv2Predicate(predicate, testData[i], testSchema); + assertTrue(evalResult.isPresent(), "Evaluation result should be present"); + assertEquals(evalResult.get(), expectedResults[i], String.format("Row %d: expected %s but got %s", i, expectedResults[i], evalResult.get())); + } + } else { + assertTrue(catalystExpr.isEmpty(), "Predicate should not be convertible"); + for (InternalRow row : testData) { + Optional evalResult = evaluateInternalRowOnDsv2Predicate(predicate, row, testSchema); + assertTrue(evalResult.isEmpty(), "Evaluation result should not be present"); + } + } + } + +}