Skip to content

Commit 65e3baa

Browse files
committed
feat: wrap parsing errors into ErrInvalidCid
1 parent 85c4236 commit 65e3baa

File tree

2 files changed

+69
-25
lines changed

2 files changed

+69
-25
lines changed

cid.go

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,32 @@ import (
3737
// UnsupportedVersionString just holds an error message
3838
const UnsupportedVersionString = "<unsupported cid version>"
3939

40+
// ErrInvalidCid is an error that indicates that a CID is invalid.
41+
type ErrInvalidCid struct {
42+
Err error
43+
}
44+
45+
func (e *ErrInvalidCid) Error() string {
46+
return fmt.Sprintf("invalid cid: %s", e.Err)
47+
}
48+
49+
func (e *ErrInvalidCid) Unwrap() error {
50+
return e.Err
51+
}
52+
53+
func (e *ErrInvalidCid) Is(err error) bool {
54+
switch err.(type) {
55+
case *ErrInvalidCid:
56+
return true
57+
default:
58+
return false
59+
}
60+
}
61+
4062
var (
4163
// ErrCidTooShort means that the cid passed to decode was not long
4264
// enough to be a valid Cid
43-
ErrCidTooShort = errors.New("cid too short")
65+
ErrCidTooShort = &ErrInvalidCid{errors.New("cid too short")}
4466

4567
// ErrInvalidEncoding means that selected encoding is not supported
4668
// by this Cid version
@@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) {
90112
// incorrectly detect it as CidV1 in the Version() method
91113
dec, err := mh.Decode(mhash)
92114
if err != nil {
93-
return Undef, err
115+
return Undef, &ErrInvalidCid{err}
94116
}
95117
if dec.Code != mh.SHA2_256 || dec.Length != 32 {
96-
return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)
118+
return Undef, &ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)}
97119
}
98120
return Cid{string(mhash)}, nil
99121
}
@@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) {
177199
case Cid:
178200
return v2, nil
179201
default:
180-
return Undef, fmt.Errorf("can't parse %+v as Cid", v2)
202+
return Undef, &ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)}
181203
}
182204
}
183205

@@ -210,15 +232,15 @@ func Decode(v string) (Cid, error) {
210232
if len(v) == 46 && v[:2] == "Qm" {
211233
hash, err := mh.FromB58String(v)
212234
if err != nil {
213-
return Undef, err
235+
return Undef, &ErrInvalidCid{err}
214236
}
215237

216238
return tryNewCidV0(hash)
217239
}
218240

219241
_, data, err := mbase.Decode(v)
220242
if err != nil {
221-
return Undef, err
243+
return Undef, &ErrInvalidCid{err}
222244
}
223245

224246
return Cast(data)
@@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
240262
// check encoding is valid
241263
_, err := mbase.NewEncoder(encoding)
242264
if err != nil {
243-
return -1, err
265+
return -1, &ErrInvalidCid{err}
244266
}
245267

246268
return encoding, nil
@@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
260282
func Cast(data []byte) (Cid, error) {
261283
nr, c, err := CidFromBytes(data)
262284
if err != nil {
263-
return Undef, err
285+
return Undef, &ErrInvalidCid{err}
264286
}
265287

266288
if nr != len(data) {
267-
return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast")
289+
return Undef, &ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")}
268290
}
269291

270292
return c, nil
@@ -615,34 +637,34 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
615637
func CidFromBytes(data []byte) (int, Cid, error) {
616638
if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 {
617639
if len(data) < 34 {
618-
return 0, Undef, fmt.Errorf("not enough bytes for cid v0")
640+
return 0, Undef, &ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")}
619641
}
620642

621643
h, err := mh.Cast(data[:34])
622644
if err != nil {
623-
return 0, Undef, err
645+
return 0, Undef, &ErrInvalidCid{err}
624646
}
625647

626648
return 34, Cid{string(h)}, nil
627649
}
628650

629651
vers, n, err := varint.FromUvarint(data)
630652
if err != nil {
631-
return 0, Undef, err
653+
return 0, Undef, &ErrInvalidCid{err}
632654
}
633655

634656
if vers != 1 {
635-
return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
657+
return 0, Undef, &ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
636658
}
637659

638660
_, cn, err := varint.FromUvarint(data[n:])
639661
if err != nil {
640-
return 0, Undef, err
662+
return 0, Undef, &ErrInvalidCid{err}
641663
}
642664

643665
mhnr, _, err := mh.MHFromBytes(data[n+cn:])
644666
if err != nil {
645-
return 0, Undef, err
667+
return 0, Undef, &ErrInvalidCid{err}
646668
}
647669

648670
l := n + cn + mhnr
@@ -705,32 +727,32 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
705727
// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
706728
vers, err := varint.ReadUvarint(br)
707729
if err != nil {
708-
return len(br.dst), Undef, err
730+
return len(br.dst), Undef, &ErrInvalidCid{err}
709731
}
710732

711733
// If we have a CIDv0, read the rest of the bytes and cast the buffer.
712734
if vers == mh.SHA2_256 {
713735
if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
714-
return len(br.dst) + n, Undef, err
736+
return len(br.dst) + n, Undef, &ErrInvalidCid{err}
715737
}
716738

717739
br.dst = br.dst[:34]
718740
h, err := mh.Cast(br.dst)
719741
if err != nil {
720-
return len(br.dst), Undef, err
742+
return len(br.dst), Undef, &ErrInvalidCid{err}
721743
}
722744

723745
return len(br.dst), Cid{string(h)}, nil
724746
}
725747

726748
if vers != 1 {
727-
return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
749+
return len(br.dst), Undef, &ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
728750
}
729751

730752
// CID block encoding multicodec.
731753
_, err = varint.ReadUvarint(br)
732754
if err != nil {
733-
return len(br.dst), Undef, err
755+
return len(br.dst), Undef, &ErrInvalidCid{err}
734756
}
735757

736758
// We could replace most of the code below with go-multihash's ReadMultihash.
@@ -741,19 +763,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
741763
// Multihash hash function code.
742764
_, err = varint.ReadUvarint(br)
743765
if err != nil {
744-
return len(br.dst), Undef, err
766+
return len(br.dst), Undef, &ErrInvalidCid{err}
745767
}
746768

747769
// Multihash digest length.
748770
mhl, err := varint.ReadUvarint(br)
749771
if err != nil {
750-
return len(br.dst), Undef, err
772+
return len(br.dst), Undef, &ErrInvalidCid{err}
751773
}
752774

753775
// Refuse to make large allocations to prevent OOMs due to bugs.
754776
const maxDigestAlloc = 32 << 20 // 32MiB
755777
if mhl > maxDigestAlloc {
756-
return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
778+
return len(br.dst), Undef, &ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)}
757779
}
758780

759781
// Fine to convert mhl to int, given maxDigestAlloc.
@@ -772,15 +794,15 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
772794
if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
773795
// We can't use len(br.dst) here,
774796
// as we've only read n bytes past prefixLength.
775-
return prefixLength + n, Undef, err
797+
return prefixLength + n, Undef, &ErrInvalidCid{err}
776798
}
777799

778800
// This simply ensures the multihash is valid.
779801
// TODO: consider removing this bit, as it's probably redundant;
780802
// for now, it helps ensure consistency with CidFromBytes.
781803
_, _, err = mh.MHFromBytes(br.dst[mhStart:])
782804
if err != nil {
783-
return len(br.dst), Undef, err
805+
return len(br.dst), Undef, &ErrInvalidCid{err}
784806
}
785807

786808
return len(br.dst), Cid{string(br.dst)}, nil

cid_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
crand "crypto/rand"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"io"
910
"math/rand"
@@ -227,6 +228,9 @@ func TestEmptyString(t *testing.T) {
227228
if err == nil {
228229
t.Fatal("shouldnt be able to parse an empty cid")
229230
}
231+
if !errors.Is(err, &ErrInvalidCid{}) {
232+
t.Fatal("error must be ErrInvalidCid")
233+
}
230234
}
231235

232236
func TestV0Handling(t *testing.T) {
@@ -282,6 +286,9 @@ func TestV0ErrorCases(t *testing.T) {
282286
if err == nil {
283287
t.Fatal("should have failed to decode that ref")
284288
}
289+
if !errors.Is(err, &ErrInvalidCid{}) {
290+
t.Fatal("error must be ErrInvalidCid")
291+
}
285292
}
286293

287294
func TestNewPrefixV1(t *testing.T) {
@@ -749,6 +756,9 @@ func TestBadParse(t *testing.T) {
749756
if err == nil {
750757
t.Fatal("expected to fail to parse an invalid CIDv1 CID")
751758
}
759+
if !errors.Is(err, &ErrInvalidCid{}) {
760+
t.Fatal("error must be ErrInvalidCid")
761+
}
752762
}
753763

754764
func TestLoggable(t *testing.T) {
@@ -763,3 +773,15 @@ func TestLoggable(t *testing.T) {
763773
t.Fatalf("did not get expected loggable form (got %v)", actual)
764774
}
765775
}
776+
777+
func TestErrInvalidCid(t *testing.T) {
778+
_, err := Decode("not-a-cid")
779+
if err == nil {
780+
t.Fatal("expected error")
781+
}
782+
783+
is := errors.Is(err, &ErrInvalidCid{})
784+
if !is {
785+
t.Fatal("expected error to be ErrInvalidCid")
786+
}
787+
}

0 commit comments

Comments
 (0)