Skip to content
This repository was archived by the owner on May 11, 2020. It is now read-only.

Commit 46451e6

Browse files
laizysbinet
authored andcommitted
wasm: fix arbitrary memory allocation attack
1 parent f8c9745 commit 46451e6

File tree

3 files changed

+92
-58
lines changed

3 files changed

+92
-58
lines changed

wasm/read.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,34 @@
55
package wasm
66

77
import (
8+
"bytes"
89
"encoding/binary"
910
"io"
1011

1112
"github.com/go-interpreter/wagon/wasm/leb128"
1213
)
1314

14-
func readBytes(r io.Reader, n int) ([]byte, error) {
15-
bytes := make([]byte, n)
16-
_, err := io.ReadFull(r, bytes)
17-
if err != nil {
18-
return bytes, err
15+
// to avoid memory attack
16+
const maxInitialCap = 10 * 1024
17+
18+
func getInitialCap(count uint32) uint32 {
19+
if count > maxInitialCap {
20+
return maxInitialCap
1921
}
22+
return count
23+
}
2024

21-
return bytes, nil
25+
func readBytes(r io.Reader, n int) ([]byte, error) {
26+
if n == 0 {
27+
return nil, nil
28+
}
29+
limited := io.LimitReader(r, int64(n))
30+
buf := &bytes.Buffer{}
31+
num, _ := buf.ReadFrom(limited)
32+
if num == int64(n) {
33+
return buf.Bytes(), nil
34+
}
35+
return nil, io.ErrUnexpectedEOF
2236
}
2337

2438
func readBytesUint(r io.Reader) ([]byte, error) {

wasm/section.go

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ func (sr *sectionsReader) readSection(r *readpos.ReadPos) (bool, error) {
159159
s.Start = r.CurPos
160160

161161
sectionBytes := new(bytes.Buffer)
162-
sectionBytes.Grow(int(payloadDataLen))
162+
163+
sectionBytes.Grow(int(getInitialCap(payloadDataLen)))
163164
sectionReader := io.LimitReader(io.TeeReader(r, sectionBytes), int64(payloadDataLen))
164165

165166
var sec Section
@@ -295,11 +296,14 @@ func (s *SectionTypes) ReadPayload(r io.Reader) error {
295296
if err != nil {
296297
return err
297298
}
298-
s.Entries = make([]FunctionSig, int(count))
299-
for i := range s.Entries {
300-
if err = s.Entries[i].UnmarshalWASM(r); err != nil {
299+
300+
s.Entries = make([]FunctionSig, 0, getInitialCap(count))
301+
for i := uint32(0); i < count; i++ {
302+
var sig FunctionSig
303+
if err := sig.UnmarshalWASM(r); err != nil {
301304
return err
302305
}
306+
s.Entries = append(s.Entries, sig)
303307
}
304308
return nil
305309
}
@@ -334,12 +338,14 @@ func (s *SectionImports) ReadPayload(r io.Reader) error {
334338
if err != nil {
335339
return err
336340
}
337-
s.Entries = make([]ImportEntry, count)
338-
for i := range s.Entries {
339-
err = s.Entries[i].UnmarshalWASM(r)
340-
if err != nil {
341+
342+
s.Entries = make([]ImportEntry, 0, getInitialCap(count))
343+
for i := uint32(0); i < count; i++ {
344+
var entry ImportEntry
345+
if err := entry.UnmarshalWASM(r); err != nil {
341346
return err
342347
}
348+
s.Entries = append(s.Entries, entry)
343349
}
344350
return nil
345351
}
@@ -440,13 +446,13 @@ func (s *SectionFunctions) ReadPayload(r io.Reader) error {
440446
if err != nil {
441447
return err
442448
}
443-
s.Types = make([]uint32, count)
444-
for i := range s.Types {
449+
s.Types = make([]uint32, 0, getInitialCap(count))
450+
for i := uint32(0); i < count; i++ {
445451
t, err := leb128.ReadVarUint32(r)
446452
if err != nil {
447453
return err
448454
}
449-
s.Types[i] = t
455+
s.Types = append(s.Types, t)
450456
}
451457
return nil
452458
}
@@ -478,12 +484,14 @@ func (s *SectionTables) ReadPayload(r io.Reader) error {
478484
if err != nil {
479485
return err
480486
}
481-
s.Entries = make([]Table, count)
482-
for i := range s.Entries {
483-
err = s.Entries[i].UnmarshalWASM(r)
484-
if err != nil {
487+
488+
s.Entries = make([]Table, 0, getInitialCap(count))
489+
for i := uint32(0); i < count; i++ {
490+
var entry Table
491+
if err = entry.UnmarshalWASM(r); err != nil {
485492
return err
486493
}
494+
s.Entries = append(s.Entries, entry)
487495
}
488496
return nil
489497
}
@@ -515,12 +523,13 @@ func (s *SectionMemories) ReadPayload(r io.Reader) error {
515523
if err != nil {
516524
return err
517525
}
518-
s.Entries = make([]Memory, count)
519-
for i := range s.Entries {
520-
err = s.Entries[i].UnmarshalWASM(r)
521-
if err != nil {
526+
s.Entries = make([]Memory, 0, getInitialCap(count))
527+
for i := uint32(0); i < count; i++ {
528+
var entry Memory
529+
if err = entry.UnmarshalWASM(r); err != nil {
522530
return err
523531
}
532+
s.Entries = append(s.Entries, entry)
524533
}
525534
return nil
526535
}
@@ -552,13 +561,15 @@ func (s *SectionGlobals) ReadPayload(r io.Reader) error {
552561
if err != nil {
553562
return err
554563
}
555-
s.Globals = make([]GlobalEntry, count)
564+
565+
s.Globals = make([]GlobalEntry, 0, getInitialCap(count))
556566
logger.Printf("%d global entries\n", count)
557-
for i := range s.Globals {
558-
err = s.Globals[i].UnmarshalWASM(r)
559-
if err != nil {
567+
for i := uint32(0); i < count; i++ {
568+
var global GlobalEntry
569+
if err = global.UnmarshalWASM(r); err != nil {
560570
return err
561571
}
572+
s.Globals = append(s.Globals, global)
562573
}
563574
return nil
564575
}
@@ -616,7 +627,8 @@ func (s *SectionExports) ReadPayload(r io.Reader) error {
616627
if err != nil {
617628
return err
618629
}
619-
s.Entries = make(map[string]ExportEntry, count)
630+
631+
s.Entries = make(map[string]ExportEntry, getInitialCap(count))
620632
for i := uint32(0); i < count; i++ {
621633
var entry ExportEntry
622634
err = entry.UnmarshalWASM(r)
@@ -735,12 +747,14 @@ func (s *SectionElements) ReadPayload(r io.Reader) error {
735747
if err != nil {
736748
return err
737749
}
738-
s.Entries = make([]ElementSegment, count)
739-
for i := range s.Entries {
740-
err = s.Entries[i].UnmarshalWASM(r)
741-
if err != nil {
750+
751+
s.Entries = make([]ElementSegment, 0, getInitialCap(count))
752+
for i := uint32(0); i < count; i++ {
753+
var element ElementSegment
754+
if err = element.UnmarshalWASM(r); err != nil {
742755
return err
743756
}
757+
s.Entries = append(s.Entries, element)
744758
}
745759
return nil
746760
}
@@ -778,14 +792,13 @@ func (s *ElementSegment) UnmarshalWASM(r io.Reader) error {
778792
if err != nil {
779793
return err
780794
}
781-
s.Elems = make([]uint32, numElems)
782-
783-
for i := range s.Elems {
795+
s.Elems = make([]uint32, 0, getInitialCap(numElems))
796+
for i := uint32(0); i < numElems; i++ {
784797
e, err := leb128.ReadVarUint32(r)
785798
if err != nil {
786799
return err
787800
}
788-
s.Elems[i] = e
801+
s.Elems = append(s.Elems, e)
789802
}
790803

791804
return nil
@@ -825,14 +838,16 @@ func (s *SectionCode) ReadPayload(r io.Reader) error {
825838
if err != nil {
826839
return err
827840
}
828-
s.Bodies = make([]FunctionBody, count)
841+
s.Bodies = make([]FunctionBody, 0, getInitialCap(count))
829842
logger.Printf("%d function bodies\n", count)
830843

831-
for i := range s.Bodies {
844+
for i := uint32(0); i < count; i++ {
832845
logger.Printf("Reading function %d\n", i)
833-
if err = s.Bodies[i].UnmarshalWASM(r); err != nil {
846+
var body FunctionBody
847+
if err = body.UnmarshalWASM(r); err != nil {
834848
return err
835849
}
850+
s.Bodies = append(s.Bodies, body)
836851
}
837852
return nil
838853
}
@@ -864,9 +879,8 @@ func (f *FunctionBody) UnmarshalWASM(r io.Reader) error {
864879
return err
865880
}
866881

867-
body := make([]byte, bodySize)
868-
869-
if _, err = io.ReadFull(r, body); err != nil {
882+
body, err := readBytes(r, int(bodySize))
883+
if err != nil {
870884
return err
871885
}
872886

@@ -876,12 +890,14 @@ func (f *FunctionBody) UnmarshalWASM(r io.Reader) error {
876890
if err != nil {
877891
return err
878892
}
879-
f.Locals = make([]LocalEntry, localCount)
893+
f.Locals = make([]LocalEntry, 0, getInitialCap(localCount))
880894

881-
for i := range f.Locals {
882-
if err = f.Locals[i].UnmarshalWASM(bytesReader); err != nil {
895+
for i := uint32(0); i < localCount; i++ {
896+
var local LocalEntry
897+
if err = local.UnmarshalWASM(bytesReader); err != nil {
883898
return err
884899
}
900+
f.Locals = append(f.Locals, local)
885901
}
886902

887903
logger.Printf("bodySize: %d, localCount: %d\n", bodySize, localCount)
@@ -961,11 +977,13 @@ func (s *SectionData) ReadPayload(r io.Reader) error {
961977
if err != nil {
962978
return err
963979
}
964-
s.Entries = make([]DataSegment, count)
965-
for i := range s.Entries {
966-
if err = s.Entries[i].UnmarshalWASM(r); err != nil {
980+
s.Entries = make([]DataSegment, 0, getInitialCap(count))
981+
for i := uint32(0); i < count; i++ {
982+
var entry DataSegment
983+
if err = entry.UnmarshalWASM(r); err != nil {
967984
return err
968985
}
986+
s.Entries = append(s.Entries, entry)
969987
}
970988
return nil
971989
}

wasm/types.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,26 +134,28 @@ func (f *FunctionSig) UnmarshalWASM(r io.Reader) error {
134134
if err != nil {
135135
return err
136136
}
137-
f.ParamTypes = make([]ValueType, paramCount)
137+
f.ParamTypes = make([]ValueType, 0, getInitialCap(paramCount))
138138

139-
for i := range f.ParamTypes {
140-
err = f.ParamTypes[i].UnmarshalWASM(r)
141-
if err != nil {
139+
for i := uint32(0); i < paramCount; i++ {
140+
var v ValueType
141+
if err = v.UnmarshalWASM(r); err != nil {
142142
return err
143143
}
144+
f.ParamTypes = append(f.ParamTypes, v)
144145
}
145146

146147
returnCount, err := leb128.ReadVarUint32(r)
147148
if err != nil {
148149
return err
149150
}
150151

151-
f.ReturnTypes = make([]ValueType, returnCount)
152-
for i := range f.ReturnTypes {
153-
err = f.ReturnTypes[i].UnmarshalWASM(r)
154-
if err != nil {
152+
f.ReturnTypes = make([]ValueType, 0, getInitialCap(returnCount))
153+
for i := uint32(0); i < returnCount; i++ {
154+
var v ValueType
155+
if err = v.UnmarshalWASM(r); err != nil {
155156
return err
156157
}
158+
f.ReturnTypes = append(f.ReturnTypes, v)
157159
}
158160

159161
return nil

0 commit comments

Comments
 (0)