Skip to content
This repository was archived by the owner on May 11, 2020. It is now read-only.

Commit 49dc095

Browse files
laizysbinet
authored andcommitted
wasm: fix leb128 decoding
1 parent 32f7a52 commit 49dc095

File tree

3 files changed

+246
-45
lines changed

3 files changed

+246
-45
lines changed

validate/validate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ func verifyBody(fn *wasm.FunctionSig, body *wasm.FunctionBody, module *wasm.Modu
175175
vm.setPolymorphic()
176176

177177
case ops.I32Const:
178-
_, err := vm.fetchVarUint()
178+
_, err := vm.fetchVarInt()
179179
if err != nil {
180180
return vm, err
181181
}

wasm/leb128/read.go

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,66 +7,94 @@
77
package leb128
88

99
import (
10+
"errors"
1011
"io"
1112
)
1213

13-
// ReadVarUint32 reads a LEB128 encoded unsigned 32-bit integer from r, and
14-
// returns the integer value, and the error (if any).
15-
func ReadVarUint32(r io.Reader) (uint32, error) {
16-
var (
17-
b = make([]byte, 1)
18-
shift uint
19-
res uint32
20-
err error
21-
)
14+
// readVarUint reads an unsigned integer of size n defined in https://webassembly.github.io/spec/core/binary/values.html#binary-int
15+
// readVarUint panics if n>64.
16+
func readVarUint(r io.Reader, n uint) (uint64, error) {
17+
if n > 64 {
18+
panic(errors.New("leb128: n must <= 64"))
19+
}
20+
p := make([]byte, 1)
21+
var res uint64
22+
var shift uint
2223
for {
23-
if _, err = io.ReadFull(r, b); err != nil {
24-
return res, err
24+
_, err := io.ReadFull(r, p)
25+
if err != nil {
26+
return 0, err
2527
}
28+
b := uint64(p[0])
29+
switch {
30+
case b < 1<<7 && b < 1<<n:
31+
res += (1 << shift) * b
32+
return res, nil
33+
case b >= 1<<7 && n > 7:
34+
res += (1 << shift) * (b - 1<<7)
35+
shift += 7
36+
n -= 7
37+
default:
38+
return 0, errors.New("leb128: invalid uint")
39+
}
40+
}
41+
}
2642

27-
cur := uint32(b[0])
28-
res |= (cur & 0x7f) << (shift)
29-
if cur&0x80 == 0 {
43+
// readVarint reads a signed integer of size n, defined in https://webassembly.github.io/spec/core/binary/values.html#binary-int
44+
// readVarint panics if n>64.
45+
func readVarint(r io.Reader, n uint) (int64, error) {
46+
if n > 64 {
47+
panic(errors.New("leb128: n must <= 64"))
48+
}
49+
p := make([]byte, 1)
50+
var res int64
51+
var shift uint
52+
for {
53+
_, err := io.ReadFull(r, p)
54+
if err != nil {
55+
return 0, err
56+
}
57+
b := int64(p[0])
58+
switch {
59+
case b < 1<<6 && uint64(b) < uint64(1<<(n-1)):
60+
res += (1 << shift) * b
3061
return res, nil
62+
case b >= 1<<6 && b < 1<<7 && uint64(b)+1<<(n-1) >= 1<<7:
63+
res += (1 << shift) * (b - 1<<7)
64+
return res, nil
65+
case b >= 1<<7 && n > 7:
66+
res += (1 << shift) * (b - 1<<7)
67+
shift += 7
68+
n -= 7
69+
default:
70+
return 0, errors.New("leb128: invalid int")
3171
}
32-
shift += 7
3372
}
3473
}
3574

75+
// ReadVarUint32 reads a LEB128 encoded unsigned 32-bit integer from r, and
76+
// returns the integer value, and the error (if any).
77+
func ReadVarUint32(r io.Reader) (uint32, error) {
78+
n, err := readVarUint(r, 32)
79+
if err != nil {
80+
return 0, err
81+
}
82+
return uint32(n), nil
83+
}
84+
3685
// ReadVarint32 reads a LEB128 encoded signed 32-bit integer from r, and
3786
// returns the integer value, and the error (if any).
3887
func ReadVarint32(r io.Reader) (int32, error) {
39-
n, err := ReadVarint64(r)
40-
return int32(n), err
88+
n, err := readVarint(r, 32)
89+
if err != nil {
90+
return 0, err
91+
}
92+
93+
return int32(n), nil
4194
}
4295

4396
// ReadVarint64 reads a LEB128 encoded signed 64-bit integer from r, and
4497
// returns the integer value, and the error (if any).
4598
func ReadVarint64(r io.Reader) (int64, error) {
46-
var (
47-
b = make([]byte, 1)
48-
shift uint
49-
sign int64 = -1
50-
res int64
51-
err error
52-
)
53-
54-
for {
55-
if _, err = io.ReadFull(r, b); err != nil {
56-
return res, err
57-
}
58-
59-
cur := int64(b[0])
60-
res |= (cur & 0x7f) << shift
61-
shift += 7
62-
sign <<= 7
63-
if cur&0x80 == 0 {
64-
break
65-
}
66-
}
67-
68-
if ((sign >> 1) & res) != 0 {
69-
res |= sign
70-
}
71-
return res, nil
99+
return readVarint(r, 64)
72100
}

wasm/leb128/read_test.go

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ package leb128
66

77
import (
88
"bytes"
9+
"crypto/rand"
10+
"errors"
911
"fmt"
1012
"io"
1113
"testing"
@@ -50,8 +52,19 @@ var casesInt = []struct {
5052
{b: []byte{0x80, 0x80, 0x80, 0xfd, 0x07}, v: 2141192192},
5153
}
5254

55+
var varint32Cases = []struct {
56+
b []byte
57+
v int32
58+
}{
59+
{[]byte{0x80, 0x80, 0x80, 0x80, 0x78}, -2147483648}, // int32 min
60+
{[]byte{0xff, 0xff, 0xff, 0xff, 0x07}, 2147483647}, //int32 max
61+
{[]byte{0x80, 0x40}, -8192},
62+
{[]byte{0x80, 0xc0, 0x00}, 8192},
63+
{[]byte{135, 0x01}, 135},
64+
}
65+
5366
func TestReadVarint32(t *testing.T) {
54-
for _, c := range casesInt {
67+
for _, c := range varint32Cases {
5568
t.Run(fmt.Sprint(c.v), func(t *testing.T) {
5669
n, err := ReadVarint32(bytes.NewReader(c.b))
5770
if err != nil {
@@ -70,3 +83,163 @@ func TestReadVarint32Err(t *testing.T) {
7083
t.Fatalf("got err=%v, want=%v", got, want)
7184
}
7285
}
86+
87+
func TestReadWriteInt64(t *testing.T) {
88+
buf := make([]byte, 16)
89+
for i := 0; i < 1000000; i++ {
90+
rand.Read(buf)
91+
reader := bytes.NewReader(buf)
92+
val, err := ReadVarint64(reader)
93+
if err != nil {
94+
continue
95+
}
96+
readLen := len(buf) - reader.Len()
97+
if readLen > (64+6)/7 { // ceil(N/7) bytes
98+
t.Fatalf("read len:%d larger then ceil(N/7) bytes", readLen)
99+
}
100+
101+
buf2 := bytes.NewBuffer(nil)
102+
WriteVarint64(buf2, val)
103+
if readLen <= len(buf2.Bytes()) {
104+
if !bytes.HasPrefix(buf, buf2.Bytes()) {
105+
t.Fatalf(fmt.Sprintf("val:%d, origin buf:%v, buf2: %v", val, buf, buf2.Bytes()))
106+
}
107+
}
108+
}
109+
110+
}
111+
112+
func TestReadWriteInt32(t *testing.T) {
113+
buf := make([]byte, 16)
114+
for i := 0; i < 1000000; i++ {
115+
rand.Read(buf)
116+
117+
reader := bytes.NewReader(buf)
118+
val, err := ReadVarint32(reader)
119+
if err != nil {
120+
continue
121+
}
122+
readLen := len(buf) - reader.Len()
123+
if readLen > (32+6)/7 { // ceil(N/7) bytes
124+
t.Fatalf("read len:%d larger then ceil(N/7) bytes", readLen)
125+
}
126+
127+
buf2 := bytes.NewBuffer(nil)
128+
WriteVarint64(buf2, int64(val))
129+
if readLen <= len(buf2.Bytes()) {
130+
if !bytes.HasPrefix(buf, buf2.Bytes()) {
131+
t.Fatalf(fmt.Sprintf("val:%d, origin buf:%v, buf2: %v", val, buf, buf2.Bytes()))
132+
}
133+
}
134+
}
135+
136+
}
137+
138+
func TestReadWriteUint32(t *testing.T) {
139+
buf := make([]byte, 16)
140+
for i := 0; i < 100000; i++ {
141+
rand.Read(buf)
142+
143+
reader := bytes.NewReader(buf)
144+
val, err := ReadVarUint32(reader)
145+
if err != nil {
146+
continue
147+
}
148+
readLen := len(buf) - reader.Len()
149+
if readLen > (32+6)/7 { // ceil(N/7) bytes
150+
t.Fatalf("read len:%d larger then ceil(N/7) bytes", readLen)
151+
}
152+
153+
buf2 := bytes.NewBuffer(nil)
154+
WriteVarUint32(buf2, val)
155+
if readLen <= len(buf2.Bytes()) {
156+
if !bytes.HasPrefix(buf, buf2.Bytes()) {
157+
t.Fatalf(fmt.Sprintf("val:%d, origin buf:%v, buf2: %v", val, buf, buf2.Bytes()))
158+
}
159+
}
160+
}
161+
}
162+
163+
func TestCompareReadVarint(t *testing.T) {
164+
buf := make([]byte, 16)
165+
for n := uint(1); n <= 64; n++ {
166+
for i := 0; i < 100000; i++ {
167+
rand.Read(buf)
168+
169+
val2, err2 := readVarintRecur(bytes.NewReader(buf), n)
170+
val1, err1 := readVarint(bytes.NewReader(buf), n)
171+
if fmt.Sprint(err1) != fmt.Sprint(err2) || val1 != val2 {
172+
t.Fatalf(fmt.Sprintf("buf: %v, val1:%d, val2: %d", buf, val1, val2))
173+
}
174+
}
175+
176+
}
177+
}
178+
179+
func TestCompareReadVarUint(t *testing.T) {
180+
buf := make([]byte, 16)
181+
for n := uint(1); n <= 64; n++ {
182+
for i := 0; i < 100000; i++ {
183+
rand.Read(buf)
184+
185+
val2, err2 := readVarUintRecur(bytes.NewReader(buf), n)
186+
val1, err1 := readVarUint(bytes.NewReader(buf), n)
187+
if fmt.Sprint(err1) != fmt.Sprint(err2) || val1 != val2 {
188+
t.Fatalf(fmt.Sprintf("buf: %v, val1:%d, val2: %d", buf, val1, val2))
189+
}
190+
}
191+
192+
}
193+
}
194+
195+
func readVarUintRecur(r io.Reader, n uint) (uint64, error) {
196+
if n > 64 {
197+
panic(errors.New("leb128: n must <= 64"))
198+
}
199+
p := make([]byte, 1)
200+
_, err := io.ReadFull(r, p)
201+
if err != nil {
202+
return 0, err
203+
}
204+
b := uint64(p[0])
205+
switch {
206+
case b < 1<<7 && b < 1<<n:
207+
return b, nil
208+
case b >= 1<<7 && n > 7:
209+
m, err := readVarUint(r, n-7)
210+
if err != nil {
211+
return 0, err
212+
}
213+
214+
return (1<<7)*m + (b - 1<<7), nil
215+
default:
216+
return 0, errors.New("leb128: invalid uint")
217+
}
218+
}
219+
220+
func readVarintRecur(r io.Reader, n uint) (int64, error) {
221+
if n > 64 {
222+
panic(errors.New("leb128: n must <= 64"))
223+
}
224+
p := make([]byte, 1)
225+
_, err := io.ReadFull(r, p)
226+
if err != nil {
227+
return 0, err
228+
}
229+
b := int64(p[0])
230+
switch {
231+
case b < 1<<6 && uint64(b) < uint64(1<<(n-1)):
232+
return b, nil
233+
case b >= 1<<6 && b < 1<<7 && uint64(b)+1<<(n-1) >= 1<<7:
234+
return b - 1<<7, nil
235+
case b >= 1<<7 && n > 7:
236+
m, err := readVarint(r, n-7)
237+
if err != nil {
238+
return 0, err
239+
}
240+
241+
return (1<<7)*m + (b - 1<<7), nil
242+
default:
243+
return 0, errors.New("leb128: invalid int")
244+
}
245+
}

0 commit comments

Comments
 (0)