Skip to content

Commit 572c0da

Browse files
committed
address comments
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
1 parent 2e10fbd commit 572c0da

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

sql-plugin/src/main/scala/com/nvidia/spark/rapids/sequencefile/GpuSequenceFileReaders.scala

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,37 @@ private[sequencefile] final class HostBinaryListBufferer(
9696
}
9797

9898
def addBytes(bytes: Array[Byte], offset: Int, len: Int): Unit = {
99+
val newEnd = dataLocation + len
100+
if (newEnd > Int.MaxValue) {
101+
throw new IllegalStateException(
102+
s"Binary column child size $newEnd would exceed INT32 offset limit")
103+
}
99104
growOffsetsIfNeeded()
100-
val end = dataLocation + len
101-
growDataIfNeeded(end)
102-
offsetsBuffer.setInt(numRows.toLong * DType.INT32.getSizeInBytes, dataLocation.toInt)
105+
growDataIfNeeded(newEnd)
106+
val offsetPosition = numRows.toLong * DType.INT32.getSizeInBytes
107+
val startDataLocation = dataLocation
103108
dataBuffer.setBytes(dataLocation, bytes, offset, len)
104-
dataLocation = end
109+
dataLocation = newEnd
110+
// Write offset only after successful data write
111+
offsetsBuffer.setInt(offsetPosition, startDataLocation.toInt)
105112
numRows += 1
106113
}
107114

108115
def addValueBytes(valueBytes: SequenceFile.ValueBytes, len: Int): Unit = {
116+
val newEnd = dataLocation + len
117+
if (newEnd > Int.MaxValue) {
118+
throw new IllegalStateException(
119+
s"Binary column child size $newEnd would exceed INT32 offset limit")
120+
}
109121
growOffsetsIfNeeded()
110-
val end = dataLocation + len
111-
growDataIfNeeded(end)
112-
offsetsBuffer.setInt(numRows.toLong * DType.INT32.getSizeInBytes, dataLocation.toInt)
122+
growDataIfNeeded(newEnd)
123+
val offsetPosition = numRows.toLong * DType.INT32.getSizeInBytes
124+
val startDataLocation = dataLocation
113125
out.seek(dataLocation)
114126
valueBytes.writeUncompressedBytes(dos)
115127
dataLocation = out.getPos
128+
// Write offset only after successful data write
129+
offsetsBuffer.setInt(offsetPosition, startDataLocation.toInt)
116130
numRows += 1
117131
}
118132

@@ -149,6 +163,9 @@ private[sequencefile] final class HostBinaryListBufferer(
149163
}
150164
}
151165
offsetsBuffer = null
166+
// The stream wrappers (out, dos) don't hold independent resources - they just wrap the
167+
// dataBuffer which is now owned by childHost. Setting to null without close() is intentional
168+
// to avoid attempting operations on the transferred buffer.
152169
out = null
153170
dos = null
154171

@@ -327,7 +344,7 @@ class SequenceFilePartitionReader(
327344
val recBytes = recordBytes(keyLen, valueLen)
328345

329346
// If this record doesn't fit, keep it for the next batch (unless it's the first row)
330-
if (rows > 0 && recBytes > 0 && bytes + recBytes > maxBytesPerBatch) {
347+
if (rows > 0 && bytes + recBytes > maxBytesPerBatch) {
331348
pending = Some(makePending(keyLen, valueLen))
332349
keepReading = false
333350
} else {

tests/src/test/scala/com/nvidia/spark/rapids/SequenceFileBinaryFileFormatSuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,4 +463,36 @@ class SequenceFileBinaryFileFormatSuite extends AnyFunSuite {
463463
}
464464
}
465465
}
466+
467+
test("Split boundary handling - records starting before boundary are read") {
468+
withTempDir("seqfile-split-test") { tmpDir =>
469+
val file = new File(tmpDir, "split-test.seq")
470+
val conf = new Configuration()
471+
472+
// Create file with multiple records using raw record format (consistent with other tests)
473+
val numRecords = 100
474+
val payloads = (0 until numRecords).map { i =>
475+
s"record-$i-with-some-padding-data".getBytes(StandardCharsets.UTF_8)
476+
}.toArray
477+
478+
writeSequenceFileWithRawRecords(file, conf, payloads)
479+
480+
withSparkSession { spark =>
481+
// Read entire file
482+
val df = spark.read
483+
.format("com.nvidia.spark.rapids.SequenceFileBinaryFileFormat")
484+
.load(file.getAbsolutePath)
485+
486+
val results = df.select("key", "value").collect()
487+
assert(results.length == numRecords,
488+
s"Expected $numRecords records, got ${results.length}")
489+
490+
// Verify all records present and no duplicates
491+
val indices = results.map(r => bytesToInt(r.getAs[Array[Byte]](0))).sorted.toSeq
492+
val expected = (0 until numRecords).toSeq
493+
assert(indices == expected,
494+
"Records missing or duplicated")
495+
}
496+
}
497+
}
466498
}

0 commit comments

Comments
 (0)