Commit 96fb839

rlp: improve nil pointer handling (#20064)
* rlp: improve nil pointer handling

In both encoder and decoder, the rules for encoding nil pointers were a bit hard to understand, and didn't leave much choice. Since RLP allows two empty values (empty list, empty string), any protocol built on RLP must choose either of these values to represent the null value in a certain context.

This change adds choice in the form of two new struct tags, "nilString" and "nilList". These can be used to specify how a nil pointer value is encoded. The "nil" tag still exists, but its implementation is now explicit and defines exactly how nil pointers are handled in a single place.

Another important change in this commit is how nil pointers and the Encoder interface interact. The EncodeRLP method was previously called even on nil values, which was supposed to give users a choice of how their value would be handled when nil. It turns out this is a stupid idea. If you create a network protocol containing an object defined in another package, it's better to be able to say that the object should be a list or string when nil in the definition of the protocol message rather than defining the encoding of nil on the object itself.

As of this commit, the encoding rules for pointers now take precedence over the Encoder interface rule. I think the "nil" tag will work fine for most cases. For special kinds of objects which are a struct in Go but strings in RLP, code using the object can specify the desired encoding of nil using the "nilString" and "nilList" tags.

* rlp: propagate struct field type errors

If a struct contained fields of undecodable type, the encoder and decoder would panic instead of returning an error. Fix this by propagating type errors in makeStruct{Writer,Decoder} and add a test.
1 parent 3b6c990 commit 96fb839
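
The choice described in the commit message comes down to which of RLP's two empty values stands in for a nil pointer. The following is a minimal standalone sketch of that choice, not the package's encoder: `nilKind`, `nilAsString`, `nilAsList` and `encodeNil` are hypothetical names introduced for illustration. The byte values 0x80 (empty string) and 0xC0 (empty list) are the ones that appear inside the test inputs `C180` and `C1C0` in rlp/decode_test.go.

```go
package main

import "fmt"

// nilKind is a hypothetical stand-in for the empty value selected by a struct tag.
type nilKind int

const (
	nilAsString nilKind = iota // corresponds to the "nilString" tag
	nilAsList                  // corresponds to the "nilList" tag
)

// encodeNil returns the single RLP byte that represents a nil pointer:
// 0x80 is the empty string, 0xC0 is the empty list.
func encodeNil(k nilKind) byte {
	if k == nilAsList {
		return 0xC0
	}
	return 0x80
}

func main() {
	fmt.Printf("nilString: %#x\n", encodeNil(nilAsString))
	fmt.Printf("nilList:   %#x\n", encodeNil(nilAsList))
}
```

Keeping this decision on the field tag, rather than inside an EncodeRLP method, is what lets the protocol message definition decide how a nil value from another package is represented.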

7 files changed, +416 -244 lines changed

rlp/decode.go

Lines changed: 54 additions & 114 deletions
@@ -55,81 +55,23 @@ var (
 	}
 )

-// Decoder is implemented by types that require custom RLP
-// decoding rules or need to decode into private fields.
+// Decoder is implemented by types that require custom RLP decoding rules or need to decode
+// into private fields.
 //
-// The DecodeRLP method should read one value from the given
-// Stream. It is not forbidden to read less or more, but it might
-// be confusing.
+// The DecodeRLP method should read one value from the given Stream. It is not forbidden to
+// read less or more, but it might be confusing.
 type Decoder interface {
 	DecodeRLP(*Stream) error
 }

-// Decode parses RLP-encoded data from r and stores the result in the
-// value pointed to by val. Val must be a non-nil pointer. If r does
-// not implement ByteReader, Decode will do its own buffering.
+// Decode parses RLP-encoded data from r and stores the result in the value pointed to by
+// val. Please see package-level documentation for the decoding rules. Val must be a
+// non-nil pointer.
 //
-// Decode uses the following type-dependent decoding rules:
+// If r does not implement ByteReader, Decode will do its own buffering.
 //
-// If the type implements the Decoder interface, decode calls
-// DecodeRLP.
-//
-// To decode into a pointer, Decode will decode into the value pointed
-// to. If the pointer is nil, a new value of the pointer's element
-// type is allocated. If the pointer is non-nil, the existing value
-// will be reused.
-//
-// To decode into a struct, Decode expects the input to be an RLP
-// list. The decoded elements of the list are assigned to each public
-// field in the order given by the struct's definition. The input list
-// must contain an element for each decoded field. Decode returns an
-// error if there are too few or too many elements.
-//
-// The decoding of struct fields honours certain struct tags, "tail",
-// "nil" and "-".
-//
-// The "-" tag ignores fields.
-//
-// For an explanation of "tail", see the example.
-//
-// The "nil" tag applies to pointer-typed fields and changes the decoding
-// rules for the field such that input values of size zero decode as a nil
-// pointer. This tag can be useful when decoding recursive types.
-//
-//     type StructWithEmptyOK struct {
-//         Foo *[20]byte `rlp:"nil"`
-//     }
-//
-// To decode into a slice, the input must be a list and the resulting
-// slice will contain the input elements in order. For byte slices,
-// the input must be an RLP string. Array types decode similarly, with
-// the additional restriction that the number of input elements (or
-// bytes) must match the array's length.
-//
-// To decode into a Go string, the input must be an RLP string. The
-// input bytes are taken as-is and will not necessarily be valid UTF-8.
-//
-// To decode into an unsigned integer type, the input must also be an RLP
-// string. The bytes are interpreted as a big endian representation of
-// the integer. If the RLP string is larger than the bit size of the
-// type, Decode will return an error. Decode also supports *big.Int.
-// There is no size limit for big integers.
-//
-// To decode into a boolean, the input must contain an unsigned integer
-// of value zero (false) or one (true).
-//
-// To decode into an interface value, Decode stores one of these
-// in the value:
-//
-//     []interface{}, for RLP lists
-//     []byte, for RLP strings
-//
-// Non-empty interface types are not supported, nor are signed integers,
-// floating point numbers, maps, channels and functions.
-//
-// Note that Decode does not set an input limit for all readers
-// and may be vulnerable to panics cause by huge value sizes. If
-// you need an input limit, use
+// Note that Decode does not set an input limit for all readers and may be vulnerable to
+// panics cause by huge value sizes. If you need an input limit, use
 //
 //     NewStream(r, limit).Decode(val)
 func Decode(r io.Reader, val interface{}) error {
@@ -140,9 +82,8 @@ func Decode(r io.Reader, val interface{}) error {
 	return stream.Decode(val)
 }

-// DecodeBytes parses RLP data from b into val.
-// Please see the documentation of Decode for the decoding rules.
-// The input must contain exactly one value and no trailing data.
+// DecodeBytes parses RLP data from b into val. Please see package-level documentation for
+// the decoding rules. The input must contain exactly one value and no trailing data.
 func DecodeBytes(b []byte, val interface{}) error {
 	r := bytes.NewReader(b)

@@ -211,14 +152,15 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
 	switch {
 	case typ == rawValueType:
 		return decodeRawValue, nil
-	case typ.Implements(decoderInterface):
 		return decodeDecoder, nil
-	case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface):
-		return decodeDecoderNoPtr, nil
 	case typ.AssignableTo(reflect.PtrTo(bigInt)):
 		return decodeBigInt, nil
 	case typ.AssignableTo(bigInt):
 		return decodeBigIntNoPtr, nil
+	case kind == reflect.Ptr:
+		return makePtrDecoder(typ, tags)
+	case reflect.PtrTo(typ).Implements(decoderInterface):
+		return decodeDecoder, nil
 	case isUint(kind):
 		return decodeUint, nil
 	case kind == reflect.Bool:
@@ -229,11 +171,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
 		return makeListDecoder(typ, tags)
 	case kind == reflect.Struct:
 		return makeStructDecoder(typ)
-	case kind == reflect.Ptr:
-		if tags.nilOK {
-			return makeOptionalPtrDecoder(typ)
-		}
-		return makePtrDecoder(typ)
 	case kind == reflect.Interface:
 		return decodeInterface, nil
 	default:
@@ -448,6 +385,11 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
 	if err != nil {
 		return nil, err
 	}
+	for _, f := range fields {
+		if f.info.decoderErr != nil {
+			return nil, structFieldError{typ, f.index, f.info.decoderErr}
+		}
+	}
 	dec := func(s *Stream, val reflect.Value) (err error) {
 		if _, err := s.List(); err != nil {
 			return wrapStreamError(err, typ)
@@ -465,15 +407,22 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
 	return dec, nil
 }

-// makePtrDecoder creates a decoder that decodes into
-// the pointer's element type.
-func makePtrDecoder(typ reflect.Type) (decoder, error) {
+// makePtrDecoder creates a decoder that decodes into the pointer's element type.
+func makePtrDecoder(typ reflect.Type, tag tags) (decoder, error) {
 	etype := typ.Elem()
 	etypeinfo := cachedTypeInfo1(etype, tags{})
-	if etypeinfo.decoderErr != nil {
+	switch {
+	case etypeinfo.decoderErr != nil:
 		return nil, etypeinfo.decoderErr
+	case !tag.nilOK:
+		return makeSimplePtrDecoder(etype, etypeinfo), nil
+	default:
+		return makeNilPtrDecoder(etype, etypeinfo, tag.nilKind), nil
 	}
-	dec := func(s *Stream, val reflect.Value) (err error) {
+}
+
+func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder {
+	return func(s *Stream, val reflect.Value) (err error) {
 		newval := val
 		if val.IsNil() {
 			newval = reflect.New(etype)
@@ -483,30 +432,35 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) {
 		}
 		return err
 	}
-	return dec, nil
 }

-// makeOptionalPtrDecoder creates a decoder that decodes empty values
-// as nil. Non-empty values are decoded into a value of the element type,
-// just like makePtrDecoder does.
+// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty
+// values are decoded into a value of the element type, just like makePtrDecoder does.
 //
 // This decoder is used for pointer-typed struct fields with struct tag "nil".
-func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
-	etype := typ.Elem()
-	etypeinfo := cachedTypeInfo1(etype, tags{})
-	if etypeinfo.decoderErr != nil {
-		return nil, etypeinfo.decoderErr
-	}
-	dec := func(s *Stream, val reflect.Value) (err error) {
+func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, nilKind Kind) decoder {
+	typ := reflect.PtrTo(etype)
+	nilPtr := reflect.Zero(typ)
+	return func(s *Stream, val reflect.Value) (err error) {
 		kind, size, err := s.Kind()
-		if err != nil || size == 0 && kind != Byte {
+		if err != nil {
+			val.Set(nilPtr)
+			return wrapStreamError(err, typ)
+		}
+		// Handle empty values as a nil pointer.
+		if kind != Byte && size == 0 {
+			if kind != nilKind {
+				return &decodeError{
+					msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind),
+					typ: typ,
+				}
+			}
 			// rearm s.Kind. This is important because the input
 			// position must advance to the next value even though
 			// we don't read anything.
 			s.kind = -1
-			// set the pointer to nil.
-			val.Set(reflect.Zero(typ))
-			return err
+			val.Set(nilPtr)
+			return nil
 		}
 		newval := val
 		if val.IsNil() {
@@ -517,7 +471,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
 		}
 		return err
 	}
-	return dec, nil
 }

 var ifsliceType = reflect.TypeOf([]interface{}{})
@@ -546,21 +499,8 @@ func decodeInterface(s *Stream, val reflect.Value) error {
 	return nil
 }

-// This decoder is used for non-pointer values of types
-// that implement the Decoder interface using a pointer receiver.
-func decodeDecoderNoPtr(s *Stream, val reflect.Value) error {
-	return val.Addr().Interface().(Decoder).DecodeRLP(s)
-}
-
 func decodeDecoder(s *Stream, val reflect.Value) error {
-	// Decoder instances are not handled using the pointer rule if the type
-	// implements Decoder with pointer receiver (i.e. always)
-	// because it might handle empty values specially.
-	// We need to allocate one here in this case, like makePtrDecoder does.
-	if val.Kind() == reflect.Ptr && val.IsNil() {
-		val.Set(reflect.New(val.Type().Elem()))
-	}
-	return val.Interface().(Decoder).DecodeRLP(s)
+	return val.Addr().Interface().(Decoder).DecodeRLP(s)
 }

 // Kind represents the kind of value contained in an RLP stream.
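
The empty-value check in makeNilPtrDecoder can be illustrated without the Stream machinery. The sketch below is a simplification under stated assumptions: it inspects only the first byte of a value, and `kind`, `classify` and `decodeNilPtr` are hypothetical names that are not part of the rlp package. It mirrors the rule in the diff: an empty value of the expected kind yields a nil pointer, an empty value of the other kind is an error, and anything else is left to the element decoder.

```go
package main

import "fmt"

// kind mirrors the three value kinds named in the diff (Byte, String, List).
type kind int

const (
	kindByte kind = iota
	kindString
	kindList
)

func (k kind) String() string {
	return [...]string{"Byte", "String", "List"}[k]
}

// classify reports the kind of an RLP value from its first byte and whether
// that byte is one of the two empty values (0x80 or 0xC0).
// Simplification: only the canonical single-byte empty encodings are recognized.
func classify(first byte) (k kind, empty bool) {
	switch {
	case first < 0x80:
		return kindByte, false
	case first < 0xC0:
		return kindString, first == 0x80
	default:
		return kindList, first == 0xC0
	}
}

// decodeNilPtr applies the nil-pointer rule for a field whose tag expects
// the empty value `want`. It returns isNil=true when the pointer should be
// set to nil, and an error when the wrong empty value was supplied.
func decodeNilPtr(first byte, want kind) (isNil bool, err error) {
	got, empty := classify(first)
	if !empty {
		return false, nil // caller decodes into the element type as usual
	}
	if got != want {
		return false, fmt.Errorf("wrong kind of empty value (got %v, want %v)", got, want)
	}
	return true, nil
}

func main() {
	for _, b := range []byte{0xC0, 0x80, 0x03} {
		isNil, err := decodeNilPtr(b, kindList)
		fmt.Printf("first byte %#x, want List: nil=%v err=%v\n", b, isNil, err)
	}
}
```

The three bytes in `main` correspond to the element bytes of the "nilList" test inputs `C1C0`, `C180` and `C103` in rlp/decode_test.go.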

rlp/decode_test.go

Lines changed: 87 additions & 8 deletions
@@ -327,6 +327,10 @@ type recstruct struct {
 	Child *recstruct `rlp:"nil"`
 }

+type invalidNilTag struct {
+	X []byte `rlp:"nil"`
+}
+
 type invalidTail1 struct {
 	A uint `rlp:"tail"`
 	B string
@@ -353,6 +357,18 @@ type tailPrivateFields struct {
 	x, y bool
 }

+type nilListUint struct {
+	X *uint `rlp:"nilList"`
+}
+
+type nilStringSlice struct {
+	X *[]uint `rlp:"nilString"`
+}
+
+type intField struct {
+	X int
+}
+
 var (
 	veryBigInt = big.NewInt(0).Add(
 		big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
@@ -485,20 +501,20 @@ var decodeTests = []decodeTest{
 		error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
 	},
 	{
-		input: "C0",
-		ptr: new(invalidTail1),
-		error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)",
-	},
-	{
-		input: "C0",
-		ptr: new(invalidTail2),
-		error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)",
+		input: "C103",
+		ptr: new(intField),
+		error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)",
 	},
 	{
 		input: "C50102C20102",
 		ptr: new(tailUint),
 		error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]",
 	},
+	{
+		input: "C0",
+		ptr: new(invalidNilTag),
+		error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`,
+	},

 	// struct tag "tail"
 	{
@@ -521,6 +537,16 @@ var decodeTests = []decodeTest{
 		ptr: new(tailPrivateFields),
 		value: tailPrivateFields{A: 1, Tail: []uint{2, 3}},
 	},
+	{
+		input: "C0",
+		ptr: new(invalidTail1),
+		error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`,
+	},
+	{
+		input: "C0",
+		ptr: new(invalidTail2),
+		error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`,
+	},

 	// struct tag "-"
 	{
@@ -529,6 +555,43 @@ var decodeTests = []decodeTest{
 		value: hasIgnoredField{A: 1, C: 2},
 	},

+	// struct tag "nilList"
+	{
+		input: "C180",
+		ptr: new(nilListUint),
+		error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X",
+	},
+	{
+		input: "C1C0",
+		ptr: new(nilListUint),
+		value: nilListUint{},
+	},
+	{
+		input: "C103",
+		ptr: new(nilListUint),
+		value: func() interface{} {
+			v := uint(3)
+			return nilListUint{X: &v}
+		}(),
+	},
+
+	// struct tag "nilString"
+	{
+		input: "C1C0",
+		ptr: new(nilStringSlice),
+		error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X",
+	},
+	{
+		input: "C180",
+		ptr: new(nilStringSlice),
+		value: nilStringSlice{},
+	},
+	{
+		input: "C2C103",
+		ptr: new(nilStringSlice),
+		value: nilStringSlice{X: &[]uint{3}},
+	},
+
 	// RawValue
 	{input: "01", ptr: new(RawValue), value: RawValue(unhex("01"))},
 	{input: "82FFFF", ptr: new(RawValue), value: RawValue(unhex("82FFFF"))},
@@ -672,6 +735,22 @@ func TestDecodeDecoder(t *testing.T) {
 	}
 }

+func TestDecodeDecoderNilPointer(t *testing.T) {
+	var s struct {
+		T1 *testDecoder `rlp:"nil"`
+		T2 *testDecoder
+	}
+	if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil {
+		t.Fatalf("Decode error: %v", err)
+	}
+	if s.T1 != nil {
+		t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called)
+	}
+	if s.T2 == nil || !s.T2.called {
+		t.Errorf("decoder T2 not allocated/called")
+	}
+}
+
 type byteDecoder byte

 func (bd *byteDecoder) DecodeRLP(s *Stream) error {
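
The second part of the commit, propagating struct field type errors, can be pictured as a check that runs once when the struct decoder is built instead of a panic during decoding. The sketch below is a loose analogue written with the standard reflect package only. `checkStruct`, `intField` and `uintField` are hypothetical names for illustration, the unsupported kinds are taken from the documentation removed in this diff (signed integers, floating point numbers, maps, channels and functions), and the error text only approximates the message expected by the `intField` test case.

```go
package main

import (
	"fmt"
	"reflect"
)

// intField has a signed integer field, which is not serializable.
type intField struct {
	X int
}

// uintField has an unsigned integer field, which is supported.
type uintField struct {
	X uint
}

// checkStruct reports an error naming the first exported field whose kind
// is unsupported, rather than panicking when the field is later processed.
func checkStruct(v interface{}) error {
	typ := reflect.TypeOf(v)
	if typ == nil || typ.Kind() != reflect.Struct {
		return fmt.Errorf("checkStruct: %v is not a struct", typ)
	}
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		if f.PkgPath != "" {
			continue // unexported fields are skipped
		}
		switch f.Type.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Float32, reflect.Float64, reflect.Map, reflect.Chan, reflect.Func:
			return fmt.Errorf("type %v is not RLP-serializable (struct field %v.%s)",
				f.Type, typ, f.Name)
		}
	}
	return nil
}

func main() {
	fmt.Println(checkStruct(intField{}))
	fmt.Println(checkStruct(uintField{}))
}
```

Returning the error from the constructor means callers of Decode see an ordinary error value for a badly typed struct, which is the behaviour the new `intField` test asserts.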
