Skip to content

Commit 45c5506

Browse files
committed
[spark] Support converting a `>=` AND `<=` pair of SparkPredicates over the same transform into a Paimon Between LeafPredicate
1 parent 8012ff7 commit 45c5506

File tree

3 files changed

+83
-3
lines changed

3 files changed

+83
-3
lines changed

paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkV2FilterConverter.scala

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818

1919
package org.apache.paimon.spark
2020

21-
import org.apache.paimon.predicate.{Predicate, PredicateBuilder, Transform}
21+
import org.apache.paimon.predicate.{GreaterOrEqual, LeafPredicate, LessOrEqual, Predicate, PredicateBuilder, Transform}
2222
import org.apache.paimon.spark.util.SparkExpressionConverter.{toPaimonLiteral, toPaimonTransform}
2323
import org.apache.paimon.types.RowType
2424

2525
import org.apache.spark.internal.Logging
2626
import org.apache.spark.sql.connector.expressions.{Expression, Literal}
2727
import org.apache.spark.sql.connector.expressions.filter.{And, Not, Or, Predicate => SparkPredicate}
2828

29+
import java.util.Objects
30+
2931
import scala.collection.JavaConverters._
3032

3133
/** Conversion from [[SparkPredicate]] to [[Predicate]]. */
@@ -125,7 +127,12 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
125127

126128
case AND =>
127129
val and = sparkPredicate.asInstanceOf[And]
128-
PredicateBuilder.and(convert(and.left), convert(and.right()))
130+
val leftPredicate = convert(and.left)
131+
val rightPredicate = convert(and.right())
132+
convertToBetweenFunction(leftPredicate, rightPredicate) match {
133+
case Some(predicate) => predicate
134+
case _ => PredicateBuilder.and(leftPredicate, rightPredicate)
135+
}
129136

130137
case OR =>
131138
val or = sparkPredicate.asInstanceOf[Or]
@@ -169,6 +176,42 @@ case class SparkV2FilterConverter(rowType: RowType) extends Logging {
169176
}
170177
}
171178

179+
/**
 * Tries to fuse a complementary `>=` / `<=` pair of leaf predicates into a single
 * Between leaf predicate, e.g. `a >= x AND a <= y` becomes `a BETWEEN x AND y`,
 * which lets the reader evaluate one predicate instead of two.
 *
 * @param leftPredicate  already-converted left child of the AND
 * @param rightPredicate already-converted right child of the AND
 * @return Some(between) when both sides are [[LeafPredicate]]s over the same
 *         transform and form a GreaterOrEqual/LessOrEqual pair (in either order);
 *         None otherwise, in which case the caller falls back to a plain AND.
 */
private def convertToBetweenFunction(
    leftPredicate: Predicate,
    rightPredicate: Predicate): Option[Predicate] = {
  (leftPredicate, rightPredicate) match {
    // Both sides must be leaves over the same transform (same column/expression);
    // the guard replaces an early `return` and keeps the match exhaustive-by-fallthrough.
    case (left: LeafPredicate, right: LeafPredicate)
        if Objects.equals(left.transform(), right.transform()) =>
      (left.function(), right.function()) match {
        case (_: GreaterOrEqual, _: LessOrEqual) =>
          // left carries the inclusive lower bound, right the inclusive upper bound
          Some(builder.between(left.transform(), left.literals().get(0), right.literals().get(0)))
        case (_: LessOrEqual, _: GreaterOrEqual) =>
          // reversed order: right carries the lower bound, left the upper bound
          Some(builder.between(left.transform(), right.literals().get(0), left.literals().get(0)))
        case _ => None
      }
    case _ => None
  }
}
214+
172215
private object UnaryPredicate {
173216
def unapply(sparkPredicate: SparkPredicate): Option[Transform] = {
174217
sparkPredicate.children() match {

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowIdPushDownTestBase.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ abstract class RowIdPushDownTestBase extends PaimonSparkTestBase {
6565
sql("SELECT * FROM t WHERE _ROW_ID IN (6, 7)"),
6666
Seq()
6767
)
68+
checkAnswer(
69+
sql("SELECT * FROM t WHERE _ROW_ID BETWEEN 0 AND 2"),
70+
Seq(Row(0, 0, "0"), Row(1, 1, "1"), Row(2, 2, "2"))
71+
)
6872

6973
// 2.CompoundPredicate
7074
checkAnswer(

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/SparkV2FilterConverterTestBase.scala

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
package org.apache.paimon.spark.sql
2020

2121
import org.apache.paimon.data.{BinaryString, Decimal, Timestamp}
22-
import org.apache.paimon.predicate.PredicateBuilder
22+
import org.apache.paimon.predicate.{Between, LeafPredicate, PredicateBuilder}
2323
import org.apache.paimon.spark.{PaimonSparkTestBase, SparkV2FilterConverter}
2424
import org.apache.paimon.spark.util.shim.TypeUtils.treatPaimonTimestampTypeAsSparkTimestampType
2525
import org.apache.paimon.table.source.DataSplit
@@ -295,6 +295,39 @@ abstract class SparkV2FilterConverterTestBase extends PaimonSparkTestBase {
295295
assert(scanFilesCount(filter) == 3)
296296
}
297297

298+
test("V2Filter: Between") {
  // 1. A literal BETWEEN filter converts straight to builder.between.
  val betweenFilter = "string_col BETWEEN 'a' AND 'r'"
  val converted = converter.convert(v2Filter(betweenFilter)).get
  assert(
    converted.equals(
      builder.between(0, BinaryString.fromString("a"), BinaryString.fromString("r"))))
  checkAnswer(
    sql(s"SELECT string_col from test_tbl WHERE $betweenFilter ORDER BY string_col"),
    Seq(Row("hello"), Row("hi"), Row("paimon"))
  )

  // 2. `>=` and `<=` over the same transform are fused into a Between leaf.
  val sameTransformFilter =
    "CONCAT(string_col, '_suffix') >= 'a' AND CONCAT(string_col, '_suffix') <= 'r'"
  val fused = converter.convert(v2Filter(sameTransformFilter)).get
  assert(fused.isInstanceOf[LeafPredicate])
  assert(fused.asInstanceOf[LeafPredicate].function.isInstanceOf[Between])
  checkAnswer(
    sql(s"SELECT string_col from test_tbl WHERE $sameTransformFilter ORDER BY string_col"),
    Seq(Row("hello"), Row("hi"), Row("paimon"))
  )

  // 3. Different transforms must NOT be fused into a Between.
  val differentTransformFilter =
    "CONCAT(string_col, '_suffix1') >= 'a' AND CONCAT(string_col, '_suffix2') <= 'r'"
  val notFusedTransform = converter.convert(v2Filter(differentTransformFilter)).get
  assert(!notFusedTransform.isInstanceOf[LeafPredicate])

  // 4. Different columns must NOT be fused into a Between.
  val differentColumnFilter = "string_col >= 'a' AND int_col <= 2"
  val notFusedColumn = converter.convert(v2Filter(differentColumnFilter)).get
  assert(!notFusedColumn.isInstanceOf[LeafPredicate])
}
330+
298331
test("V2Filter: And") {
299332
val filter = "int_col > 1 AND int_col < 3"
300333
val actual = converter.convert(v2Filter(filter)).get

0 commit comments

Comments
 (0)