Commit e964947

fix: Fall back to Spark when hashing decimals with precision > 18 (#1325)
* fall back to Spark when hashing decimals with precision > 18
* murmur3 checks
* refactor
* fix
* address feedback
1 parent 07274e8 commit e964947
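
To illustrate the user-visible effect (a sketch only; the table and column names are hypothetical, and the queries mirror the new tests below): before this commit, Comet's native murmur3/xxhash64 produced Spark-incompatible results for wide decimals, and after it those expressions fall back to Spark.

// Hypothetical example of the boundary this commit enforces.
spark.sql("create table t(small decimal(18, 2), wide decimal(19, 2)) using parquet")

// precision <= 18: Comet still runs the hash functions natively.
spark.sql("select hash(small), xxhash64(small) from t")

// precision > 18: Comet now declines the expression, so Spark executes it
// and the results match Spark exactly.
spark.sql("select hash(wide), xxhash64(wide) from t")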

3 files changed: 127 additions, 42 deletions
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (3 additions, 29 deletions)
@@ -2176,35 +2176,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case Murmur3Hash(children, seed) =>
-        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(IntegerType).get)
-          .setIntVal(seed)
-        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
-
-      case XxHash64(children, seed) =>
-        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(LongType).get)
-          .setLongVal(seed)
-        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
+      case _: Murmur3Hash => CometMurmur3Hash.convert(expr, inputs, binding)
+
+      case _: XxHash64 => CometXxHash64.convert(expr, inputs, binding)
 
       case Sha2(left, numBits) =>
         if (!numBits.foldable) {
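
Both hash expressions now route through dedicated serde objects defined in the new file below. From the overrides there, the contract being implemented is presumably shaped as follows (a sketch; the actual CometExpressionSerde trait is defined elsewhere in the org.apache.comet.serde package):

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

// Inferred shape of the serde contract: convert a Spark expression into a
// protobuf expression, or return None to make Comet fall back to Spark.
trait CometExpressionSerde {
  def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr]
}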
New file in spark/src/main/scala/org/apache/comet/serde/ (85 additions, 0 deletions)
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, XxHash64}
+import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarExprToProtoWithReturnType, serializeDataType, supportedDataType}
+
+object CometXxHash64 extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (!HashUtils.isSupportedType(expr)) {
+      return None
+    }
+    val hash = expr.asInstanceOf[XxHash64]
+    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
+    val seedBuilder = ExprOuterClass.Literal
+      .newBuilder()
+      .setDatatype(serializeDataType(LongType).get)
+      .setLongVal(hash.seed)
+    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+    // the seed is put at the end of the arguments
+    scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
+  }
+}
+
+object CometMurmur3Hash extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (!HashUtils.isSupportedType(expr)) {
+      return None
+    }
+    val hash = expr.asInstanceOf[Murmur3Hash]
+    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
+    val seedBuilder = ExprOuterClass.Literal
+      .newBuilder()
+      .setDatatype(serializeDataType(IntegerType).get)
+      .setIntVal(hash.seed)
+    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+    // the seed is put at the end of the arguments
+    scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
+  }
+}
+
+private object HashUtils {
+  def isSupportedType(expr: Expression): Boolean = {
+    for (child <- expr.children) {
+      child.dataType match {
+        case dt: DecimalType if dt.precision > 18 =>
+          // Spark converts decimals with precision > 18 into
+          // Java BigDecimal before hashing
+          withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
+          return false
+        case dt if !supportedDataType(dt) =>
+          withInfo(expr, s"Unsupported datatype $dt")
+          return false
+        case _ =>
+      }
+    }
+    true
+  }
+}
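
Why 18 is the boundary: Spark hashes a decimal narrow enough to fit in a Long (at most Decimal.MAX_LONG_DIGITS, i.e. 18 digits) via its unscaled long value, but converts anything wider to a java.math.BigDecimal and hashes its unscaled bytes, a code path Comet's native murmur3_hash/xxhash64 do not replicate yet (see issue #1294 referenced in the removed test). A rough paraphrase of that logic follows; it is not Spark's actual source, and hashLong/hashBytes are hypothetical stand-ins for the Murmur3/XxHash64 primitives:

import org.apache.spark.sql.types.Decimal

// Sketch of Spark's two decimal hashing paths.
def hashDecimal(
    d: Decimal,
    seed: Long,
    hashLong: (Long, Long) => Long,
    hashBytes: (Array[Byte], Long) => Long): Long = {
  if (d.precision <= Decimal.MAX_LONG_DIGITS) {
    // Small decimal: hash the unscaled Long representation directly.
    hashLong(d.toUnscaledLong, seed)
  } else {
    // Wide decimal: hash the bytes of the BigDecimal's unscaled value.
    hashBytes(d.toJavaBigDecimal.unscaledValue().toByteArray, seed)
  }
}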

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala (39 additions, 13 deletions)
@@ -1929,19 +1929,45 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
     }
   }
 
-  test("hash functions with decimal input") {
-    withTable("t1", "t2") {
-      // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it into long and hash it.
-      // Else, turn it into bytes and hash it.
-      sql("create table t1(c1 decimal(18, 2)) using parquet")
-      sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
-      checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t1 order by c1")
-
-      // TODO: comet hash function is not compatible with spark for decimal with precision greater than 18.
-      // https://github.com/apache/datafusion-comet/issues/1294
-      // sql("create table t2(c1 decimal(20, 2)) using parquet")
-      // sql("insert into t2 values(1.23), (-1.23), (0.0), (null)")
-      // checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t2 order by c1")
+  test("hash function with decimal input") {
+    val testPrecisionScales: Seq[(Int, Int)] = Seq(
+      (1, 0),
+      (17, 2),
+      (18, 2),
+      (19, 2),
+      (DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE - 1))
+    for ((p, s) <- testPrecisionScales) {
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($p, $s)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        if (p <= 18) {
+          checkSparkAnswerAndOperator("select c1, hash(c1) from t1 order by c1")
+        } else {
+          // not supported natively yet
+          checkSparkAnswer("select c1, hash(c1) from t1 order by c1")
+        }
+      }
+    }
+  }
+
+  test("xxhash64 function with decimal input") {
+    val testPrecisionScales: Seq[(Int, Int)] = Seq(
+      (1, 0),
+      (17, 2),
+      (18, 2),
+      (19, 2),
+      (DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE - 1))
+    for ((p, s) <- testPrecisionScales) {
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($p, $s)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        if (p <= 18) {
+          checkSparkAnswerAndOperator("select c1, xxhash64(c1) from t1 order by c1")
+        } else {
+          // not supported natively yet
+          checkSparkAnswer("select c1, xxhash64(c1) from t1 order by c1")
+        }
+      }
     }
   }
 
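
A note on the two assertion helpers (behavior inferred from their names and use in this suite): checkSparkAnswerAndOperator verifies both that the result matches Spark and that the query actually ran on Comet's native operator, while checkSparkAnswer only compares results, which is the right check for the precision > 18 cases that now intentionally fall back to Spark.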
