Skip to content

Commit a96acea

Browse files
authored
[GLUTEN-11402][VL] Fix decimal partition key serialization to preserve scale (#11618)
This PR fixes decimal partition value serialization by replacing toJavaBigInteger.toString with toJavaBigDecimal.unscaledValue().toString, removes fallback guard that was added by #11518 and adds additional test cases to SQLQuerySuite covering small decimals, zero-scale decimals, negative values, and multi-partition pruning.
1 parent d3fabc2 commit a96acea

File tree

3 files changed

+58
-7
lines changed

3 files changed

+58
-7
lines changed

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class VeloxIteratorApi extends IteratorApi with Logging {
161161
case _: DateType =>
162162
DateFormatter.apply().format(pv.asInstanceOf[Integer])
163163
case _: DecimalType =>
164-
pv.asInstanceOf[Decimal].toJavaBigInteger.toString
164+
pv.asInstanceOf[Decimal].toJavaBigDecimal.unscaledValue().toString
165165
case _: TimestampType =>
166166
TimestampFormatter
167167
.getFractionFormatter(ZoneOffset.UTC)

gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
2727

2828
import org.apache.spark.Partition
2929
import org.apache.spark.sql.catalyst.expressions._
30-
import org.apache.spark.sql.types.DecimalType
3130

3231
import com.google.protobuf.StringValue
3332
import io.substrait.proto.NamedStruct
@@ -139,10 +138,6 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource
139138
return validationResult
140139
}
141140

142-
if (getPartitionSchema.fields.exists(_.dataType.isInstanceOf[DecimalType])) {
143-
return ValidationResult.failed(s"Unsupported decimal partition column in native scan.")
144-
}
145-
146141
val substraitContext = new SubstraitContext
147142
val relNode = transform(substraitContext).root
148143

gluten-ut/test/src/test/scala/org/apache/gluten/sql/SQLQuerySuite.scala

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class SQLQuerySuite extends WholeStageTransformerSuite {
5858
val df = spark.createDataFrame(data).toDF("key", "value")
5959
df.createOrReplaceTempView("src")
6060

61-
// decimal
61+
// decimal with fractional truncation
6262
sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))")
6363
sql("""
6464
|insert into table dynparttest2 partition(pdec)
@@ -68,6 +68,62 @@ class SQLQuerySuite extends WholeStageTransformerSuite {
6868
sql("select * from dynparttest2"),
6969
Seq(Row(6, new java.math.BigDecimal("100.1"))))
7070
}
71+
72+
// small decimal with scale > 0
73+
withTable("dynparttest_small") {
74+
sql("create table dynparttest_small (value int) partitioned by (pdec decimal(3, 2))")
75+
sql("""
76+
|insert into table dynparttest_small partition(pdec)
77+
| select count(*), cast('1.23' as decimal(3, 2)) as pdec from src
78+
""".stripMargin)
79+
checkAnswer(
80+
sql("select * from dynparttest_small"),
81+
Seq(Row(6, new java.math.BigDecimal("1.23"))))
82+
}
83+
84+
// zero scale with no fractional part
85+
withTable("dynparttest_zero_scale") {
86+
sql("create table dynparttest_zero_scale (value int) partitioned by (pdec decimal(10, 0))")
87+
sql("""
88+
|insert into table dynparttest_zero_scale partition(pdec)
89+
| select count(*), cast('42' as decimal(10, 0)) as pdec from src
90+
""".stripMargin)
91+
checkAnswer(
92+
sql("select * from dynparttest_zero_scale"),
93+
Seq(Row(6, new java.math.BigDecimal("42"))))
94+
}
95+
96+
// negative value with scale
97+
withTable("dynparttest_neg") {
98+
sql("create table dynparttest_neg (value int) partitioned by (pdec decimal(5, 2))")
99+
sql("""
100+
|insert into table dynparttest_neg partition(pdec)
101+
| select count(*), cast('-3.14' as decimal(5, 2)) as pdec from src
102+
""".stripMargin)
103+
checkAnswer(
104+
sql("select * from dynparttest_neg"),
105+
Seq(Row(6, new java.math.BigDecimal("-3.14"))))
106+
}
107+
108+
// multiple distinct partition values
109+
withTable("dynparttest_multi") {
110+
sql("create table dynparttest_multi (value int) partitioned by (pdec decimal(4, 1))")
111+
sql("""
112+
|insert into table dynparttest_multi partition(pdec)
113+
| select count(*), cast('10.5' as decimal(4, 1)) as pdec from src
114+
""".stripMargin)
115+
sql("""
116+
|insert into table dynparttest_multi partition(pdec)
117+
| select count(*), cast('20.3' as decimal(4, 1)) as pdec from src
118+
""".stripMargin)
119+
checkAnswer(
120+
sql("select * from dynparttest_multi order by pdec"),
121+
Seq(Row(6, new java.math.BigDecimal("10.5")), Row(6, new java.math.BigDecimal("20.3"))))
122+
// partition pruning
123+
checkAnswer(
124+
sql("select * from dynparttest_multi where pdec = 10.5"),
125+
Seq(Row(6, new java.math.BigDecimal("10.5"))))
126+
}
71127
}
72128
}
73129

0 commit comments

Comments
 (0)