diff --git a/digest.go b/digest.go index 3c473ac..1aeb279 100644 --- a/digest.go +++ b/digest.go @@ -157,3 +157,21 @@ func (d Digest) Hex() string { func (d Digest) String() string { return string(d) } + +// UnmarshalText implements encoding.TextUnmarshaler (and affects encoding/json unmarshaling) +// +// This enforces that unmarshaled values are valid; otherwise Go would allow setting a Digest value to arbitrary strings, +// causing a later panic or other misuse if users forget to call Validate(). +func (d *Digest) UnmarshalText(text []byte) error { + if len(text) == 0 { // This frequently happens in `json:",omitempty"` fields, and users are presumably ready to handle that. + *d = "" + return nil + } + + value, err := Parse(string(text)) + if err != nil { + return err + } + *d = value + return nil +} diff --git a/digest_test.go b/digest_test.go index 1c8ffc8..9a34482 100644 --- a/digest_test.go +++ b/digest_test.go @@ -15,6 +15,9 @@ package digest_test import ( + "encoding" + "encoding/json" + "errors" "testing" "github.com/opencontainers/go-digest" @@ -118,3 +121,78 @@ func TestParseDigest(t *testing.T) { }) } } + +func TestDigestUnmarshalJSONValue(t *testing.T) { + var _ encoding.TextUnmarshaler = (*digest.Digest)(nil) + + for _, tc := range []struct { + name string + input string + expectedValue string + expectedError error + }{ + { + name: "empty value", + input: `{"digest":""}`, + expectedValue: "", + expectedError: nil, + }, + { + name: "no value", + input: `{}`, + expectedValue: "", + expectedError: nil, + }, + { + name: "success", + input: `{"digest":"sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`, + expectedValue: "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + expectedError: nil, + }, + { + name: "no colon", + input: `{"digest":"no-colon"}`, + expectedError: digest.ErrDigestInvalidFormat, + }, + { + name: "invalid algorithm", + input: `{"digest":"../../../../etc/issue:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`, + expectedValue: "", + expectedError: digest.ErrDigestInvalidFormat, + }, + { + name: "invalid value", + input: `{"digest":"sha256:../../../../etc/issue"}`, + expectedError: digest.ErrDigestInvalidLength, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + var dest struct { + Digest digest.Digest `json:"digest"` + } + err := json.Unmarshal([]byte(tc.input), &dest) + if tc.expectedError != nil { + if err == nil || !errors.Is(err, tc.expectedError) { + t.Fatalf("Unexpected error %#v", err) + } + } else { + if err != nil { + t.Fatalf("Unexpected error %#v", err) + } + if dest.Digest.String() != tc.expectedValue { + t.Fatalf("Unexpected value %q", dest.Digest.String()) + } + if tc.expectedValue != "" { + // Just to be extra sure: The value is valid… + if err := dest.Digest.Validate(); err != nil { + t.Fatalf("Successfully unmarshaled invalid value: %v", err) + } + // … and does not panic + _ = dest.Digest.Algorithm() + _ = dest.Digest.Hex() + } + } + }) + } +}