Skip to content

Commit 9409e9b

Browse files
erictobbee
authored andcommitted
fix: add fuzzer for mp4 box decoding and fix discovered issues
This greatly reduced the memory usage for corrupted mp4 files by checking for sizes of the box compared to the sizes of elements..
1 parent 8910b0a commit 9409e9b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+488
-78
lines changed

av1/av1codecconfigurationrecord.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package av1
22

33
import (
44
"errors"
5+
"fmt"
56
"io"
67

78
"github.com/Eyevinn/mp4ff/bits"
@@ -34,6 +35,11 @@ type CodecConfRec struct {
3435

3536
// DecodeAVCDecConfRec - decode an AV1CodecConfRec
3637
func DecodeAV1CodecConfRec(data []byte) (CodecConfRec, error) {
38+
// Minimum size is 4 bytes for the fixed header fields
39+
if len(data) < 4 {
40+
return CodecConfRec{}, fmt.Errorf("av1C: data size %d is too small (minimum 4 bytes)", len(data))
41+
}
42+
3743
av1drc := CodecConfRec{}
3844

3945
Marker := data[0] >> 7

avc/avcdecoderconfigurationrecord.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ func CreateAVCDecConfRec(spsNalus [][]byte, ppsNalus [][]byte, includePS bool) (
6161

6262
// DecodeAVCDecConfRec - decode an AVCDecConfRec
6363
func DecodeAVCDecConfRec(data []byte) (DecConfRec, error) {
64+
// Check minimum length for fixed header (6 bytes)
65+
if len(data) < 6 {
66+
return DecConfRec{}, fmt.Errorf("data too short for AVC decoder configuration record: %d bytes", len(data))
67+
}
68+
6469
configurationVersion := data[0] // Should be 1
6570
if configurationVersion != 1 {
6671
return DecConfRec{}, fmt.Errorf("AVC decoder configuration record version %d unknown",
@@ -75,29 +80,56 @@ func DecodeAVCDecConfRec(data []byte) (DecConfRec, error) {
7580
}
7681
numSPS := data[5] & 0x1f // 5 bits following 3 reserved bits
7782
pos := 6
83+
7884
spsNALUs := make([][]byte, 0, 1)
7985
for i := 0; i < int(numSPS); i++ {
86+
// Check if we have enough bytes to read NALU length
87+
if pos+2 > len(data) {
88+
return DecConfRec{}, fmt.Errorf("not enough data to read SPS NALU length at position %d", pos)
89+
}
8090
naluLength := int(binary.BigEndian.Uint16(data[pos : pos+2]))
8191
pos += 2
92+
93+
// Check if we have enough bytes to read NALU
94+
if pos+naluLength > len(data) {
95+
return DecConfRec{}, fmt.Errorf("not enough data to read SPS NALU of length %d at position %d", naluLength, pos)
96+
}
8297
spsNALUs = append(spsNALUs, data[pos:pos+naluLength])
8398
pos += naluLength
8499
}
85-
ppsNALUs := make([][]byte, 0, 1)
100+
101+
// Check if we have enough bytes to read numPPS
102+
if pos >= len(data) {
103+
return DecConfRec{}, fmt.Errorf("not enough data to read number of PPS at position %d", pos)
104+
}
86105
numPPS := data[pos]
87106
pos++
107+
108+
ppsNALUs := make([][]byte, 0, 1)
88109
for i := 0; i < int(numPPS); i++ {
110+
// Check if we have enough bytes to read NALU length
111+
if pos+2 > len(data) {
112+
return DecConfRec{}, fmt.Errorf("not enough data to read PPS NALU length at position %d", pos)
113+
}
89114
naluLength := int(binary.BigEndian.Uint16(data[pos : pos+2]))
90115
pos += 2
116+
117+
// Check if we have enough bytes to read NALU
118+
if pos+naluLength > len(data) {
119+
return DecConfRec{}, fmt.Errorf("not enough data to read PPS NALU of length %d at position %d", naluLength, pos)
120+
}
91121
ppsNALUs = append(ppsNALUs, data[pos:pos+naluLength])
92122
pos += naluLength
93123
}
124+
94125
adcr := DecConfRec{
95126
AVCProfileIndication: AVCProfileIndication,
96127
ProfileCompatibility: ProfileCompatibility,
97128
AVCLevelIndication: AVCLevelIndication,
98129
SPSnalus: spsNALUs,
99130
PPSnalus: ppsNALUs,
100131
}
132+
101133
// The rest of this structure may vary
102134
// ISO/IEC 14496-15 2017 says that
103135
// Compatible extensions to this record will extend it and
@@ -114,6 +146,10 @@ func DecodeAVCDecConfRec(data []byte) (DecConfRec, error) {
114146
adcr.NoTrailingInfo = true
115147
return adcr, nil
116148
}
149+
// Check if we have enough bytes for the trailing info
150+
if pos+4 > len(data) {
151+
return DecConfRec{}, fmt.Errorf("not enough data for trailing info at position %d", pos)
152+
}
117153
adcr.ChromaFormat = data[pos] & 0x03
118154
adcr.BitDepthLumaMinus1 = data[pos+1] & 0x07
119155
adcr.BitDepthChromaMinus1 = data[pos+2] & 0x07

bits/fixedslicereader.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ func (s *FixedSliceReader) ReadZeroTerminatedString(maxLen int) string {
167167
}
168168
startPos := s.pos
169169
maxPos := startPos + maxLen
170+
if maxPos > s.len {
171+
maxPos = s.len
172+
}
170173
for {
171174
if s.pos >= maxPos {
172175
s.err = errors.New("did not find terminating zero")
@@ -208,6 +211,10 @@ func (s *FixedSliceReader) ReadPossiblyZeroTerminatedString(maxLen int) (str str
208211
// ReadBytes - read a slice of n bytes
209212
// Return empty slice if n bytes not available
210213
func (s *FixedSliceReader) ReadBytes(n int) []byte {
214+
if n < 0 {
215+
s.err = fmt.Errorf("attempt to read negative number of bytes: %d", n)
216+
return []byte{}
217+
}
211218
if s.err != nil {
212219
return []byte{}
213220
}

mp4/box.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ func DecodeHeader(r io.Reader) (BoxHeader, error) {
198198
} else if size == 0 {
199199
return BoxHeader{}, fmt.Errorf("Size 0, meaning to end of file, not supported")
200200
}
201+
if uint64(headerLen) > size {
202+
return BoxHeader{}, fmt.Errorf("box header size %d exceeds box size %d", headerLen, size)
203+
}
201204
return BoxHeader{string(buf[4:8]), size, headerLen}, nil
202205
}
203206

@@ -380,14 +383,17 @@ func makebuf(b Box) []byte {
380383

381384
// readBoxBody reads complete box body. Returns error if not possible
382385
func readBoxBody(r io.Reader, h BoxHeader) ([]byte, error) {
383-
bodyLen := h.Size - uint64(h.Hdrlen)
384-
if bodyLen == 0 {
386+
hdrLen := uint64(h.Hdrlen)
387+
if hdrLen == h.Size {
385388
return nil, nil
386389
}
387-
body := make([]byte, bodyLen)
388-
_, err := io.ReadFull(r, body)
390+
bodyLen := h.Size - hdrLen
391+
body, err := io.ReadAll(io.LimitReader(r, int64(bodyLen)))
389392
if err != nil {
390393
return nil, err
391394
}
395+
if len(body) != int(bodyLen) {
396+
return nil, fmt.Errorf("read box body length %d does not match expected length %d", len(body), bodyLen)
397+
}
392398
return body, nil
393399
}

mp4/boxsr.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,10 @@ func DecodeHeaderSR(sr bits.SliceReader) (BoxHeader, error) {
189189
} else if size == 0 {
190190
return BoxHeader{}, fmt.Errorf("Size 0, meaning to end of file, not supported")
191191
}
192-
return BoxHeader{boxType, size, headerLen}, nil
192+
if uint64(headerLen) > size {
193+
return BoxHeader{}, fmt.Errorf("box header size %d exceeds box size %d", headerLen, size)
194+
}
195+
return BoxHeader{boxType, size, headerLen}, sr.AccError()
193196
}
194197

195198
// DecodeFile - parse and decode a file from reader r with optional file options.

mp4/co64.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,23 @@ func DecodeCo64(hdr BoxHeader, startPos uint64, r io.Reader) (Box, error) {
3232
func DecodeCo64SR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, error) {
3333
versionAndFlags := sr.ReadUint32()
3434
nrEntries := sr.ReadUint32()
35+
3536
b := &Co64Box{
36-
Version: byte(versionAndFlags >> 24),
37-
Flags: versionAndFlags & flagsMask,
38-
ChunkOffset: make([]uint64, nrEntries),
37+
Version: byte(versionAndFlags >> 24),
38+
Flags: versionAndFlags & flagsMask,
39+
}
40+
41+
if hdr.Size != b.expectedSize(nrEntries) {
42+
return nil, fmt.Errorf("co64: expected size %d, got %d", b.expectedSize(nrEntries), hdr.Size)
3943
}
4044

45+
b.ChunkOffset = make([]uint64, nrEntries)
46+
4147
for i := uint32(0); i < nrEntries; i++ {
4248
b.ChunkOffset[i] = sr.ReadUint64()
49+
if sr.AccError() != nil {
50+
return nil, sr.AccError()
51+
}
4352
}
4453
return b, sr.AccError()
4554
}
@@ -49,9 +58,13 @@ func (b *Co64Box) Type() string {
4958
return "co64"
5059
}
5160

61+
func (b *Co64Box) expectedSize(nrEntries uint32) uint64 {
62+
return uint64(boxHeaderSize + 8 + nrEntries*8)
63+
}
64+
5265
// Size - box-specific size
5366
func (b *Co64Box) Size() uint64 {
54-
return uint64(boxHeaderSize + 8 + len(b.ChunkOffset)*8)
67+
return b.expectedSize(uint32(len(b.ChunkOffset)))
5568
}
5669

5770
// Encode - write box to w

mp4/ctts.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,17 @@ func DecodeCttsSR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err
3535
entryCount := sr.ReadUint32()
3636

3737
b := &CttsBox{
38-
Version: byte(versionAndFlags >> 24),
39-
Flags: versionAndFlags & flagsMask,
40-
EndSampleNr: make([]uint32, entryCount+1),
41-
SampleOffset: make([]int32, entryCount),
38+
Version: byte(versionAndFlags >> 24),
39+
Flags: versionAndFlags & flagsMask,
4240
}
4341

42+
if hdr.Size != b.expectedSize(entryCount) {
43+
return nil, fmt.Errorf("ctts: expected size %d, got %d", b.expectedSize(entryCount), hdr.Size)
44+
}
45+
46+
b.EndSampleNr = make([]uint32, entryCount+1)
47+
b.SampleOffset = make([]int32, entryCount)
48+
4449
var endSampleNr uint32 = 0
4550
b.EndSampleNr[0] = endSampleNr
4651
for i := 0; i < int(entryCount); i++ {
@@ -58,7 +63,12 @@ func (b *CttsBox) Type() string {
5863

5964
// Size - calculated size of box
6065
func (b *CttsBox) Size() uint64 {
61-
return uint64(boxHeaderSize + 8 + len(b.SampleOffset)*8)
66+
return b.expectedSize(uint32(len(b.SampleOffset)))
67+
}
68+
69+
// expectedSize - calculate size for a given entry count
70+
func (b *CttsBox) expectedSize(entryCount uint32) uint64 {
71+
return uint64(boxHeaderSize + 8 + uint64(entryCount)*8) // 8 = version + flags + entryCount, 8 = sampleCount(4) + sampleOffset(4)
6272
}
6373

6474
// Encode - write box to w

mp4/elst.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,14 @@ func DecodeElstSR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err
4141
b := &ElstBox{
4242
Version: version,
4343
Flags: versionAndFlags & flagsMask,
44-
Entries: make([]ElstEntry, entryCount),
4544
}
4645

46+
if hdr.Size != b.expectedSize(entryCount) {
47+
return nil, fmt.Errorf("elst: expected size %d, got %d", b.expectedSize(entryCount), hdr.Size)
48+
}
49+
50+
b.Entries = make([]ElstEntry, entryCount)
51+
4752
if version == 1 {
4853
for i := 0; i < int(entryCount); i++ {
4954
b.Entries[i].SegmentDuration = sr.ReadUint64()
@@ -71,10 +76,15 @@ func (b *ElstBox) Type() string {
7176

7277
// Size - calculated size of box
7378
func (b *ElstBox) Size() uint64 {
79+
return b.expectedSize(uint32(len(b.Entries)))
80+
}
81+
82+
// expectedSize - calculate size for a given entry count
83+
func (b *ElstBox) expectedSize(entryCount uint32) uint64 {
7484
if b.Version == 1 {
75-
return uint64(boxHeaderSize + 8 + len(b.Entries)*20)
85+
return uint64(boxHeaderSize + 8 + uint64(entryCount)*20) // 8 = version + flags + entryCount, 20 = uint64 + int64 + 2*int16
7686
}
77-
return uint64(boxHeaderSize + 8 + len(b.Entries)*12) // m.Version == 0
87+
return uint64(boxHeaderSize + 8 + uint64(entryCount)*12) // 8 = version + flags + entryCount, 12 = uint32 + int32 + 2*int16
7888
}
7989

8090
// Encode - write box to w

mp4/eventmessage.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,17 @@ func DecodeSilbSR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err
166166
b.Version = uint8(versionAndFlags >> 24)
167167
b.Flags = versionAndFlags & flagsMask
168168
nrSchemes := sr.ReadUint32()
169-
b.Schemes = make([]SilbEntry, nrSchemes)
170169
for i := uint32(0); i < nrSchemes; i++ {
171170
schemeIdURI := sr.ReadZeroTerminatedString(int(hdr.payloadLen()) - 8)
172171
value := sr.ReadZeroTerminatedString(int(hdr.payloadLen()) - 8 - len(schemeIdURI) - 1)
173172
atLeastOneFlag := sr.ReadUint8() == 1
174-
b.Schemes[i] = SilbEntry{
173+
b.Schemes = append(b.Schemes, SilbEntry{
175174
SchemeIdURI: schemeIdURI,
176175
Value: value,
177176
AtLeastOneFlag: atLeastOneFlag,
177+
})
178+
if sr.AccError() != nil {
179+
return nil, sr.AccError()
178180
}
179181
}
180182
b.OtherSchemesFlag = sr.ReadUint8() == 1

mp4/fuzz_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
package mp4
5+
6+
import (
7+
"bytes"
8+
"context"
9+
"errors"
10+
"io"
11+
"os"
12+
"runtime"
13+
"strings"
14+
"testing"
15+
"time"
16+
)
17+
18+
func monitorMemory(ctx context.Context, t *testing.T, memoryLimit int) {
19+
go func() {
20+
timer := time.NewTicker(500 * time.Millisecond)
21+
defer timer.Stop()
22+
var m runtime.MemStats
23+
24+
for {
25+
select {
26+
case <-ctx.Done():
27+
return
28+
case <-timer.C:
29+
runtime.ReadMemStats(&m)
30+
if m.Alloc > uint64(memoryLimit) {
31+
t.Logf("memory limit exceeded: %d > %d", m.Alloc, memoryLimit)
32+
t.Fail()
33+
return
34+
}
35+
}
36+
}
37+
}()
38+
}
39+
40+
func FuzzDecodeBox(f *testing.F) {
41+
entries, err := os.ReadDir("testdata")
42+
if err != nil {
43+
f.Fatal(err)
44+
}
45+
46+
for _, entry := range entries {
47+
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".mp4") {
48+
testData, err := os.ReadFile("testdata/" + entry.Name())
49+
if err != nil {
50+
f.Fatal(err)
51+
}
52+
f.Add(testData)
53+
}
54+
}
55+
56+
f.Fuzz(func(t *testing.T, b []byte) {
57+
if t.Name() == "FuzzDecodeBox/75565444c6c2f1dd" {
58+
t.Skip("There is a bug in SencBox.Size() that needs to be fixed for " + t.Name())
59+
}
60+
61+
ctx, cancel := context.WithCancel(context.Background())
62+
defer cancel()
63+
monitorMemory(ctx, t, 500*1024*1024) // 500MB
64+
65+
r := bytes.NewReader(b)
66+
67+
var pos uint64 = 0
68+
for {
69+
box, err := DecodeBox(pos, r)
70+
if err != nil {
71+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
72+
break
73+
}
74+
}
75+
if box == nil {
76+
break
77+
}
78+
pos += box.Size()
79+
}
80+
})
81+
}

0 commit comments

Comments
 (0)