Skip to content

Commit 8a307d1

Browse files
kiszk authored and Robert Kruszewski committed
[SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes
## What changes were proposed in this pull request?

This PR implemented the following cleanups related to the `UnsafeWriter` class:

- Remove code duplication between `UnsafeRowWriter` and `UnsafeArrayWriter`
- Make the `BufferHolder` class internal by delegating its accessor methods to `UnsafeWriter`
- Replace `UnsafeRow.setTotalSize(...)` with `UnsafeRowWriter.setTotalSize()`

## How was this patch tested?

Tested by existing UTs.

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#20850 from kiszk/SPARK-23713.
1 parent 08f64b4 commit 8a307d1

File tree

12 files changed

+391
-408
lines changed

12 files changed

+391
-408
lines changed

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,13 +27,10 @@ import org.apache.spark.TaskContext
2727
import org.apache.spark.internal.Logging
2828
import org.apache.spark.sql.SparkSession
2929
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
30-
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
31-
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3230
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
3331
import org.apache.spark.sql.sources.v2.reader._
3432
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
3533
import org.apache.spark.sql.types.StructType
36-
import org.apache.spark.unsafe.types.UTF8String
3734

3835
/**
3936
* A [[ContinuousReader]] for data from kafka.

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010
2020
import org.apache.kafka.clients.consumer.ConsumerRecord
2121

2222
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
23-
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
2424
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2525
import org.apache.spark.unsafe.types.UTF8String
2626

2727
/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
2828
private[kafka010] class KafkaRecordToUnsafeRowConverter {
29-
private val sharedRow = new UnsafeRow(7)
30-
private val bufferHolder = new BufferHolder(sharedRow)
31-
private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
29+
private val rowWriter = new UnsafeRowWriter(7)
3230

3331
def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
34-
bufferHolder.reset()
32+
rowWriter.reset()
3533

3634
if (record.key == null) {
3735
rowWriter.setNullAt(0)
@@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter {
4644
5,
4745
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
4846
rowWriter.write(6, record.timestampType.id)
49-
sharedRow.setTotalSize(bufferHolder.totalSize)
50-
sharedRow
47+
rowWriter.getRow()
5148
}
5249
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,21 @@
3030
* this class per writing program, so that the memory segment/data buffer can be reused. Note that
3131
* for each incoming record, we should call `reset` of the BufferHolder instance before writing the record
3232
* and reuse the data buffer.
33-
*
34-
* Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
35-
* the size of the result row, after writing a record to the buffer. However, we can skip this step
36-
* if the fields of row are all fixed-length, as the size of result row is also fixed.
3733
*/
38-
public class BufferHolder {
34+
final class BufferHolder {
3935

4036
private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
4137

42-
public byte[] buffer;
43-
public int cursor = Platform.BYTE_ARRAY_OFFSET;
38+
private byte[] buffer;
39+
private int cursor = Platform.BYTE_ARRAY_OFFSET;
4440
private final UnsafeRow row;
4541
private final int fixedSize;
4642

47-
public BufferHolder(UnsafeRow row) {
43+
BufferHolder(UnsafeRow row) {
4844
this(row, 64);
4945
}
5046

51-
public BufferHolder(UnsafeRow row, int initialSize) {
47+
BufferHolder(UnsafeRow row, int initialSize) {
5248
int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
5349
if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) {
5450
throw new UnsupportedOperationException(
@@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) {
6460
/**
6561
* Grows the buffer by at least neededSize and points the row to the buffer.
6662
*/
67-
public void grow(int neededSize) {
63+
void grow(int neededSize) {
6864
if (neededSize > ARRAY_MAX - totalSize()) {
6965
throw new UnsupportedOperationException(
7066
"Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
@@ -86,11 +82,23 @@ public void grow(int neededSize) {
8682
}
8783
}
8884

89-
public void reset() {
85+
byte[] getBuffer() {
86+
return buffer;
87+
}
88+
89+
int getCursor() {
90+
return cursor;
91+
}
92+
93+
void increaseCursor(int val) {
94+
cursor += val;
95+
}
96+
97+
void reset() {
9098
cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
9199
}
92100

93-
public int totalSize() {
101+
int totalSize() {
94102
return cursor - Platform.BYTE_ARRAY_OFFSET;
95103
}
96104
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java

Lines changed: 31 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
import org.apache.spark.unsafe.Platform;
2222
import org.apache.spark.unsafe.array.ByteArrayMethods;
2323
import org.apache.spark.unsafe.bitset.BitSetMethods;
24-
import org.apache.spark.unsafe.types.CalendarInterval;
25-
import org.apache.spark.unsafe.types.UTF8String;
2624

2725
import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
2826

@@ -32,141 +30,123 @@
3230
*/
3331
public final class UnsafeArrayWriter extends UnsafeWriter {
3432

35-
private BufferHolder holder;
36-
37-
// The offset of the global buffer where we start to write this array.
38-
private int startingOffset;
39-
4033
// The number of elements in this array
4134
private int numElements;
4235

36+
// The element size in this array
37+
private int elementSize;
38+
4339
private int headerInBytes;
4440

4541
private void assertIndexIsValid(int index) {
4642
assert index >= 0 : "index (" + index + ") should >= 0";
4743
assert index < numElements : "index (" + index + ") should < " + numElements;
4844
}
4945

50-
public void initialize(BufferHolder holder, int numElements, int elementSize) {
46+
public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) {
47+
super(writer.getBufferHolder());
48+
this.elementSize = elementSize;
49+
}
50+
51+
public void initialize(int numElements) {
5152
// We need 8 bytes to store numElements in header
5253
this.numElements = numElements;
5354
this.headerInBytes = calculateHeaderPortionInBytes(numElements);
5455

55-
this.holder = holder;
56-
this.startingOffset = holder.cursor;
56+
this.startingOffset = cursor();
5757

5858
// Grows the global buffer ahead for header and fixed size data.
5959
int fixedPartInBytes =
6060
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements);
6161
holder.grow(headerInBytes + fixedPartInBytes);
6262

6363
// Write numElements and clear out null bits to header
64-
Platform.putLong(holder.buffer, startingOffset, numElements);
64+
Platform.putLong(getBuffer(), startingOffset, numElements);
6565
for (int i = 8; i < headerInBytes; i += 8) {
66-
Platform.putLong(holder.buffer, startingOffset + i, 0L);
66+
Platform.putLong(getBuffer(), startingOffset + i, 0L);
6767
}
6868

6969
// fill 0 into the remainder part of the 8-byte alignment in unsafe array
7070
for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
71-
Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
71+
Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0);
7272
}
73-
holder.cursor += (headerInBytes + fixedPartInBytes);
73+
increaseCursor(headerInBytes + fixedPartInBytes);
7474
}
7575

76-
private void zeroOutPaddingBytes(int numBytes) {
77-
if ((numBytes & 0x07) > 0) {
78-
Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
79-
}
80-
}
81-
82-
private long getElementOffset(int ordinal, int elementSize) {
76+
private long getElementOffset(int ordinal) {
8377
return startingOffset + headerInBytes + ordinal * elementSize;
8478
}
8579

86-
public void setOffsetAndSize(int ordinal, int currentCursor, int size) {
87-
assertIndexIsValid(ordinal);
88-
final long relativeOffset = currentCursor - startingOffset;
89-
final long offsetAndSize = (relativeOffset << 32) | (long)size;
90-
91-
write(ordinal, offsetAndSize);
92-
}
93-
9480
private void setNullBit(int ordinal) {
9581
assertIndexIsValid(ordinal);
96-
BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
82+
BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal);
9783
}
9884

9985
public void setNull1Bytes(int ordinal) {
10086
setNullBit(ordinal);
10187
// put zero into the corresponding field when set null
102-
Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
88+
writeByte(getElementOffset(ordinal), (byte)0);
10389
}
10490

10591
public void setNull2Bytes(int ordinal) {
10692
setNullBit(ordinal);
10793
// put zero into the corresponding field when set null
108-
Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
94+
writeShort(getElementOffset(ordinal), (short)0);
10995
}
11096

11197
public void setNull4Bytes(int ordinal) {
11298
setNullBit(ordinal);
11399
// put zero into the corresponding field when set null
114-
Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0);
100+
writeInt(getElementOffset(ordinal), 0);
115101
}
116102

117103
public void setNull8Bytes(int ordinal) {
118104
setNullBit(ordinal);
119105
// put zero into the corresponding field when set null
120-
Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
106+
writeLong(getElementOffset(ordinal), 0);
121107
}
122108

123109
public void setNull(int ordinal) { setNull8Bytes(ordinal); }
124110

125111
public void write(int ordinal, boolean value) {
126112
assertIndexIsValid(ordinal);
127-
Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
113+
writeBoolean(getElementOffset(ordinal), value);
128114
}
129115

130116
public void write(int ordinal, byte value) {
131117
assertIndexIsValid(ordinal);
132-
Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
118+
writeByte(getElementOffset(ordinal), value);
133119
}
134120

135121
public void write(int ordinal, short value) {
136122
assertIndexIsValid(ordinal);
137-
Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
123+
writeShort(getElementOffset(ordinal), value);
138124
}
139125

140126
public void write(int ordinal, int value) {
141127
assertIndexIsValid(ordinal);
142-
Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
128+
writeInt(getElementOffset(ordinal), value);
143129
}
144130

145131
public void write(int ordinal, long value) {
146132
assertIndexIsValid(ordinal);
147-
Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
133+
writeLong(getElementOffset(ordinal), value);
148134
}
149135

150136
public void write(int ordinal, float value) {
151-
if (Float.isNaN(value)) {
152-
value = Float.NaN;
153-
}
154137
assertIndexIsValid(ordinal);
155-
Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
138+
writeFloat(getElementOffset(ordinal), value);
156139
}
157140

158141
public void write(int ordinal, double value) {
159-
if (Double.isNaN(value)) {
160-
value = Double.NaN;
161-
}
162142
assertIndexIsValid(ordinal);
163-
Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
143+
writeDouble(getElementOffset(ordinal), value);
164144
}
165145

166146
public void write(int ordinal, Decimal input, int precision, int scale) {
167147
// make sure Decimal object has the same scale as DecimalType
168148
assertIndexIsValid(ordinal);
169-
if (input.changePrecision(precision, scale)) {
149+
if (input != null && input.changePrecision(precision, scale)) {
170150
if (precision <= Decimal.MAX_LONG_DIGITS()) {
171151
write(ordinal, input.toUnscaledLong());
172152
} else {
@@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
180160

181161
// Write the bytes to the variable length portion.
182162
Platform.copyMemory(
183-
bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
184-
setOffsetAndSize(ordinal, holder.cursor, numBytes);
163+
bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
164+
setOffsetAndSize(ordinal, numBytes);
185165

186166
// move the cursor forward with 8-bytes boundary
187-
holder.cursor += roundedSize;
167+
increaseCursor(roundedSize);
188168
}
189169
} else {
190170
setNull(ordinal);
191171
}
192172
}
193-
194-
public void write(int ordinal, UTF8String input) {
195-
final int numBytes = input.numBytes();
196-
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
197-
198-
// grow the global buffer before writing data.
199-
holder.grow(roundedSize);
200-
201-
zeroOutPaddingBytes(numBytes);
202-
203-
// Write the bytes to the variable length portion.
204-
input.writeToMemory(holder.buffer, holder.cursor);
205-
206-
setOffsetAndSize(ordinal, holder.cursor, numBytes);
207-
208-
// move the cursor forward.
209-
holder.cursor += roundedSize;
210-
}
211-
212-
public void write(int ordinal, byte[] input) {
213-
final int numBytes = input.length;
214-
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
215-
216-
// grow the global buffer before writing data.
217-
holder.grow(roundedSize);
218-
219-
zeroOutPaddingBytes(numBytes);
220-
221-
// Write the bytes to the variable length portion.
222-
Platform.copyMemory(
223-
input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
224-
225-
setOffsetAndSize(ordinal, holder.cursor, numBytes);
226-
227-
// move the cursor forward.
228-
holder.cursor += roundedSize;
229-
}
230-
231-
public void write(int ordinal, CalendarInterval input) {
232-
// grow the global buffer before writing data.
233-
holder.grow(16);
234-
235-
// Write the months and microseconds fields of Interval to the variable length portion.
236-
Platform.putLong(holder.buffer, holder.cursor, input.months);
237-
Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
238-
239-
setOffsetAndSize(ordinal, holder.cursor, 16);
240-
241-
// move the cursor forward.
242-
holder.cursor += 16;
243-
}
244173
}

0 commit comments

Comments
 (0)