Skip to content

Commit d7ae36a

Browse files
mgaido91dongjoon-hyun
authored andcommitted
[SPARK-25538][SQL] Zero-out all bytes when writing decimal
## What changes were proposed in this pull request? In apache#20850 when writing non-null decimals, instead of zero-ing all the 16 allocated bytes, we zero-out only the padding bytes. Since we always allocate 16 bytes, if the number of bytes needed for a decimal is lower than 9, then this means that the bytes between 8 and 16 are not zero-ed. I see 2 solutions here: - we can zero-out all the bytes in advance as it was done before apache#20850 (safer solution IMHO); - we can allocate only the needed bytes (may be a bit more efficient in terms of memory used, but I have not investigated the feasibility of this option). Hence I propose here the first solution in order to fix the correctness issue. We can eventually switch to the second if we think is more efficient later. ## How was this patch tested? Running the test attached in the JIRA + added UT Closes apache#22602 from mgaido91/SPARK-25582. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 56741c3 commit d7ae36a

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,13 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
185185
// grow the global buffer before writing data.
186186
holder.grow(16);
187187

188+
// always zero-out the 16-byte buffer
189+
Platform.putLong(getBuffer(), cursor(), 0L);
190+
Platform.putLong(getBuffer(), cursor() + 8, 0L);
191+
188192
// Make sure Decimal object has the same scale as DecimalType.
189193
// Note that we may pass in null Decimal object to set null for it.
190194
if (input == null || !input.changePrecision(precision, scale)) {
191-
// zero-out the bytes
192-
Platform.putLong(getBuffer(), cursor(), 0L);
193-
Platform.putLong(getBuffer(), cursor() + 8, 0L);
194-
195195
BitSetMethods.set(getBuffer(), startingOffset, ordinal);
196196
// keep the offset for future update
197197
setOffsetAndSize(ordinal, 0);
@@ -200,8 +200,6 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
200200
final int numBytes = bytes.length;
201201
assert numBytes <= 16;
202202

203-
zeroOutPaddingBytes(numBytes);
204-
205203
// Write the bytes to the variable length portion.
206204
Platform.copyMemory(
207205
bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.types.Decimal
22+
23+
class UnsafeRowWriterSuite extends SparkFunSuite {
24+
25+
def checkDecimalSizeInBytes(decimal: Decimal, numBytes: Int): Unit = {
26+
assert(decimal.toJavaBigDecimal.unscaledValue().toByteArray.length == numBytes)
27+
}
28+
29+
test("SPARK-25538: zero-out all bits for decimals") {
30+
val decimal1 = Decimal(0.431)
31+
decimal1.changePrecision(38, 18)
32+
checkDecimalSizeInBytes(decimal1, 8)
33+
34+
val decimal2 = Decimal(123456789.1232456789)
35+
decimal2.changePrecision(38, 18)
36+
checkDecimalSizeInBytes(decimal2, 11)
37+
// On an UnsafeRowWriter we write decimal2 first and then decimal1
38+
val unsafeRowWriter1 = new UnsafeRowWriter(1)
39+
unsafeRowWriter1.resetRowWriter()
40+
unsafeRowWriter1.write(0, decimal2, decimal2.precision, decimal2.scale)
41+
unsafeRowWriter1.reset()
42+
unsafeRowWriter1.write(0, decimal1, decimal1.precision, decimal1.scale)
43+
val res1 = unsafeRowWriter1.getRow
44+
// On a second UnsafeRowWriter we write directly decimal1
45+
val unsafeRowWriter2 = new UnsafeRowWriter(1)
46+
unsafeRowWriter2.resetRowWriter()
47+
unsafeRowWriter2.write(0, decimal1, decimal1.precision, decimal1.scale)
48+
val res2 = unsafeRowWriter2.getRow
49+
// The two rows should be the equal
50+
assert(res1 == res2)
51+
}
52+
53+
}

0 commit comments

Comments
 (0)