Skip to content

Commit 9b1162e

Browse files
committed
address comments
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
1 parent 572c0da commit 9b1162e

File tree

3 files changed, with 31 additions and 8 deletions.
Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
com.nvidia.spark.rapids.SequenceFileBinaryFileFormat
2+

sql-plugin/src/main/scala/com/nvidia/spark/rapids/sequencefile/GpuSequenceFileReaders.scala

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,8 @@ private[sequencefile] final class HostBinaryListBufferer(
7171

7272
private def growOffsetsIfNeeded(): Unit = {
7373
if (numRows + 1 > rowsAllocated) {
74-
val newRowsAllocated = math.min(rowsAllocated.toLong * 2, Int.MaxValue.toLong - 1L).toInt
74+
// Use Int.MaxValue - 2 to ensure (rowsAllocated + 1) * 4 doesn't overflow
75+
val newRowsAllocated = math.min(rowsAllocated.toLong * 2, Int.MaxValue.toLong - 2L).toInt
7576
val newSize = (newRowsAllocated.toLong + 1L) * DType.INT32.getSizeInBytes
7677
closeOnExcept(HostMemoryBuffer.allocate(newSize)) { tmpBuffer =>
7778
tmpBuffer.copyFromHostBuffer(0, offsetsBuffer, 0, offsetsBuffer.getLength)
@@ -89,6 +90,8 @@ private[sequencefile] final class HostBinaryListBufferer(
8990
newBuff.copyFromHostBuffer(0, dataBuffer, 0, dataLocation)
9091
dataBuffer.close()
9192
dataBuffer = newBuff
93+
// Clear old stream wrapper before creating new ones
94+
dos = null
9295
out = new HostMemoryOutputStream(dataBuffer)
9396
dos = new DataOutputStream(out)
9497
}
@@ -123,7 +126,13 @@ private[sequencefile] final class HostBinaryListBufferer(
123126
val offsetPosition = numRows.toLong * DType.INT32.getSizeInBytes
124127
val startDataLocation = dataLocation
125128
out.seek(dataLocation)
129+
val startPos = out.getPos
126130
valueBytes.writeUncompressedBytes(dos)
131+
val actualLen = (out.getPos - startPos).toInt
132+
if (actualLen != len) {
133+
throw new IllegalStateException(
134+
s"addValueBytes length mismatch: expected $len bytes, but wrote $actualLen bytes")
135+
}
127136
dataLocation = out.getPos
128137
// Write offset only after successful data write
129138
offsetsBuffer.setInt(offsetPosition, startDataLocation.toInt)
@@ -534,23 +543,26 @@ case class GpuSequenceFileMultiFilePartitionReaderFactory(
534543

535544
override protected def getFileFormatShortName: String = "SequenceFileBinary"
536545

537-
override protected def buildBaseColumnarReaderForCloud(
546+
private def buildSequenceFileMultiFileReader(
538547
files: Array[PartitionedFile],
539548
conf: Configuration): PartitionReader[ColumnarBatch] = {
540-
// No special cloud implementation yet; read sequentially on the task thread.
541549
new PartitionReaderWithBytesRead(
542550
new SequenceFileMultiFilePartitionReader(conf, files, readDataSchema, partitionSchema,
543551
maxReadBatchSizeRows, maxReadBatchSizeBytes, maxGpuColumnSizeBytes,
544552
metrics, queryUsesInputFile))
545553
}
546554

555+
override protected def buildBaseColumnarReaderForCloud(
556+
files: Array[PartitionedFile],
557+
conf: Configuration): PartitionReader[ColumnarBatch] = {
558+
// No special cloud implementation yet; read sequentially on the task thread.
559+
buildSequenceFileMultiFileReader(files, conf)
560+
}
561+
547562
override protected def buildBaseColumnarReaderForCoalescing(
548563
files: Array[PartitionedFile],
549564
conf: Configuration): PartitionReader[ColumnarBatch] = {
550565
// Sequential multi-file reader (no cross-file coalescing).
551-
new PartitionReaderWithBytesRead(
552-
new SequenceFileMultiFilePartitionReader(conf, files, readDataSchema, partitionSchema,
553-
maxReadBatchSizeRows, maxReadBatchSizeBytes, maxGpuColumnSizeBytes,
554-
metrics, queryUsesInputFile))
566+
buildSequenceFileMultiFileReader(files, conf)
555567
}
556568
}

tests/src/test/scala/com/nvidia/spark/rapids/SequenceFileBinaryFileFormatSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,15 @@ import org.scalatest.funsuite.AnyFunSuite
3232
import org.apache.spark.SparkException
3333
import org.apache.spark.sql.SparkSession
3434

35+
/**
36+
* Unit tests for SequenceFileBinaryFileFormat.
37+
*
38+
* Note: This test suite uses its own withSparkSession/withGpuSparkSession methods instead of
39+
* extending SparkQueryCompareTestSuite because:
40+
* 1. These tests need fresh SparkSession instances per test to avoid state pollution
41+
* 2. The tests don't need the compare-CPU-vs-GPU pattern from SparkQueryCompareTestSuite
42+
* 3. The simpler session management makes the tests more self-contained
43+
*/
3544
class SequenceFileBinaryFileFormatSuite extends AnyFunSuite {
3645

3746
private def withSparkSession(f: SparkSession => Unit): Unit = {
@@ -56,7 +65,7 @@ class SequenceFileBinaryFileFormatSuite extends AnyFunSuite {
5665
.config("spark.sql.shuffle.partitions", "1")
5766
.config("spark.plugins", "com.nvidia.spark.SQLPlugin")
5867
.config("spark.rapids.sql.enabled", "true")
59-
.config("spark.rapids.sql.test.enabled", "false")
68+
.config("spark.rapids.sql.test.enabled", "true")
6069
.getOrCreate()
6170
try {
6271
f(spark)

0 commit comments

Comments
 (0)