Skip to content

Commit 8edb901

Browse files
fixup! Ensure that decompressBuffer doesn't get reallocated by io.CopyN
1 parent eec120a commit 8edb901

File tree

4 files changed

+79
-14
lines changed

4 files changed

+79
-14
lines changed

internal/utils/buf_reader.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ type bytesBufferReader struct {
122122

123123
// NewBytesBufferReader creates a new bytesBufferReader with the given size and allocator.
124124
func NewBytesBufferReader(size int, alloc memory.Allocator) *bytesBufferReader {
125+
if alloc == nil {
126+
alloc = memory.DefaultAllocator
127+
}
125128
buf := alloc.Allocate(size)
126129
return &bytesBufferReader{
127130
alloc: alloc,
@@ -159,6 +162,9 @@ type bufferedReader struct {
159162
// except Peek will expand the internal buffer if needed rather than return
160163
// an error.
161164
func NewBufferedReader(rd Reader, sz int, alloc memory.Allocator) *bufferedReader {
165+
if alloc == nil {
166+
alloc = memory.DefaultAllocator
167+
}
162168
r := &bufferedReader{
163169
alloc: alloc,
164170
rd: rd,

parquet/file/page_reader.go

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package file
1818

1919
import (
20-
"bytes"
2120
"errors"
2221
"fmt"
2322
"io"
@@ -504,21 +503,16 @@ func (p *serializedPageReader) Page() Page {
504503
}
505504

506505
func (p *serializedPageReader) decompress(rd io.Reader, lenCompressed int, buf []byte) ([]byte, error) {
507-
// As of go1.25.3: There is an issue when bytes.Buffer and io.CopyN are used together. io.CopyN
508-
// uses io.LimitReader, which does an additional read on the underlying reader to determine EOF.
509-
// However, bytes.Buffer always attempts to read at least bytes.MinRead (which is 512 bytes) from the
510-
// underlying reader, even if there is less data available than that. So even if there are no more bytes,
511-
// the buffer must have at least bytes.MinRead capacity remaining to avoid a relocation.
512-
if p.decompressBuffer.Cap() < lenCompressed+bytes.MinRead {
513-
p.decompressBuffer.Reserve(lenCompressed + bytes.MinRead)
514-
}
515506
p.decompressBuffer.ResizeNoShrink(lenCompressed)
516-
b := bytes.NewBuffer(p.decompressBuffer.Bytes()[:0])
517-
if _, err := io.CopyN(b, rd, int64(lenCompressed)); err != nil {
507+
data := p.decompressBuffer.Bytes()
508+
n, err := io.ReadFull(rd, data)
509+
if err != nil {
518510
return nil, err
519511
}
512+
if n != lenCompressed {
513+
return nil, fmt.Errorf("parquet: expected to read %d bytes but only read %d", lenCompressed, n)
514+
}
520515

521-
data := p.decompressBuffer.Bytes()
522516
if p.cryptoCtx.DataDecryptor != nil {
523517
data = p.cryptoCtx.DataDecryptor.Decrypt(p.decompressBuffer.Bytes())
524518
}
@@ -563,6 +557,7 @@ func (p *serializedPageReader) GetDictionaryPage() (*DictionaryPage, error) {
563557
io.NewSectionReader(p.r.Outer(), p.dictOffset-p.baseOffset, p.dataOffset-p.baseOffset),
564558
readBufSize,
565559
p.mem)
560+
defer rd.Free()
566561
if err := p.readPageHeader(rd, hdr); err != nil {
567562
return nil, err
568563
}
@@ -774,6 +769,7 @@ func (p *serializedPageReader) Next() bool {
774769

775770
firstRowIdx := p.rowsSeen
776771
p.rowsSeen += int64(dataHeader.GetNumValues())
772+
777773
data, err := p.decompress(p.r, lenCompressed, buf.Bytes())
778774
if err != nil {
779775
p.err = err

parquet/file/row_group_reader.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,13 @@ func (r *RowGroupReader) GetColumnPageReader(i int) (PageReader, error) {
134134
}
135135

136136
if r.fileDecryptor == nil {
137+
stream.Free()
137138
return nil, xerrors.New("column in rowgroup is encrypted, but no file decryptor")
138139
}
139140

140141
const encryptedRowGroupsLimit = 32767
141142
if i > encryptedRowGroupsLimit {
143+
stream.Free()
142144
return nil, xerrors.New("encrypted files cannot contain more than 32767 column chunks")
143145
}
144146

parquet/reader_writer_properties_test.go

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package parquet_test
1818

1919
import (
2020
"bytes"
21+
"errors"
2122
"testing"
2223

2324
"github.com/apache/arrow-go/v18/arrow/memory"
@@ -67,7 +68,67 @@ func TestReaderPropsGetStreamInsufficient(t *testing.T) {
6768
buf := memory.NewBufferBytes([]byte(data))
6869
rdr := bytes.NewReader(buf.Bytes())
6970

70-
props := parquet.NewReaderProperties(nil)
71-
_, err := props.GetStream(rdr, 12, 15)
71+
props1 := parquet.NewReaderProperties(nil)
72+
_, err := props1.GetStream(rdr, 12, 15)
73+
assert.Error(t, err)
74+
}
75+
76+
type mockReaderAt struct{}
77+
78+
func (m *mockReaderAt) ReadAt(p []byte, off int64) (int, error) {
79+
return 0, errors.New("mock error")
80+
}
81+
82+
func TestReaderPropsGetStreamWithAllocator(t *testing.T) {
83+
pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
84+
defer pool.AssertSize(t, 0)
85+
86+
data := "data to read"
87+
buf := memory.NewBufferBytes([]byte(data))
88+
rdr := bytes.NewReader(buf.Bytes())
89+
90+
// no leak on success
91+
props := parquet.NewReaderProperties(pool)
92+
bufRdr, err := props.GetStream(rdr, 0, int64(len(data)))
93+
assert.NoError(t, err)
94+
bufRdr.Free()
95+
96+
// no leak on reader error
97+
_, err = props.GetStream(&mockReaderAt{}, 0, 10)
98+
assert.Error(t, err)
99+
100+
// no leak on insufficient read
101+
_, err = props.GetStream(rdr, 0, int64(len(data)+10))
72102
assert.Error(t, err)
73103
}
104+
105+
func TestReaderPropsGetStreamBufferedWithAllocator(t *testing.T) {
106+
pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
107+
defer pool.AssertSize(t, 0)
108+
109+
data := "data to read"
110+
rdr := bytes.NewReader(memory.NewBufferBytes([]byte(data)).Bytes())
111+
112+
props := parquet.NewReaderProperties(pool)
113+
props.BufferedStreamEnabled = true
114+
115+
buf := make([]byte, len(data))
116+
bufRdr, err := props.GetStream(rdr, 0, int64(len(data)))
117+
assert.NoError(t, err)
118+
_, err = bufRdr.Read(buf)
119+
assert.NoError(t, err)
120+
bufRdr.Free()
121+
122+
bufRdr, err = props.GetStream(&mockReaderAt{}, 0, 10)
123+
assert.NoError(t, err)
124+
_, err = bufRdr.Read(buf)
125+
assert.Error(t, err)
126+
bufRdr.Free()
127+
128+
bufRdr, err = props.GetStream(rdr, 0, int64(len(data)+10))
129+
assert.NoError(t, err)
130+
n, err := bufRdr.Read(buf)
131+
assert.NoError(t, err)
132+
assert.NotEqual(t, len(data)+10, n)
133+
bufRdr.Free()
134+
}

0 commit comments

Comments
 (0)