Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 28 additions & 33 deletions pkg/fetch/http_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, digestFunct
return buffer.NewCASBufferFromByteSlice(digest, bodyBytes, buffer.UserProvided), digest
}

// getChecksumSri parses the checksum.sri qualifier into an expected digest and a digest function to use
// getChecksumSri parses the checksum.sri qualifier into an expected digest and a digest function to use.
// Per the SRI spec, multiple checksums may be space-separated; the first supported one is used.
func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, bb_digest.Function, error) {
hashTypes := map[string]remoteexecution.DigestFunction_Value{
"sha256": remoteexecution.DigestFunction_SHA256,
Expand All @@ -183,41 +184,35 @@ func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, bb_digest.Func
"sha512": remoteexecution.DigestFunction_SHA512,
"sha256tree": remoteexecution.DigestFunction_SHA256TREE,
}
expectedDigest := ""
digestFunctionEnum := remoteexecution.DigestFunction_UNKNOWN
for _, qualifier := range qualifiers {
if qualifier.Name == "checksum.sri" {
if digestFunctionEnum != remoteexecution.DigestFunction_UNKNOWN {
return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Multiple checksum.sri provided")
for _, checksum := range strings.Fields(qualifier.Value) {
parts := strings.SplitN(checksum, "-", 2)
if len(parts) != 2 {
continue
}
hashName := parts[0]
b64hash := parts[1]

digestFunctionEnum, ok := hashTypes[hashName]
if !ok {
continue
}

decoded, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
continue
}
expectedDigest := hex.EncodeToString(decoded)

instance := util.Must(bb_digest.NewInstanceName(""))
checksumFunction, err := instance.GetDigestFunction(digestFunctionEnum, len(decoded))
if err != nil {
continue
}
return expectedDigest, checksumFunction, nil
}
parts := strings.SplitN(qualifier.Value, "-", 2)
if len(parts) != 2 {
return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Bad checksum.sri hash expression: %s", qualifier.Value)
}
hashName := parts[0]
b64hash := parts[1]

digestFunctionEnum, ok := hashTypes[hashName]
if !ok {
return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Unsupported checksum algorithm %s", hashName)
}

// Convert expected digest to hex
decoded, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Failed to decode checksum as base64 encoded %s sum: %s", hashName, err.Error())
}
expectedDigest = hex.EncodeToString(decoded)

// Convert to a proper digest function.
// Note: The Instance name doesn't matter here, this function is used only
// to give us a convenient API when actually checking the checksum.
instance := util.Must(bb_digest.NewInstanceName(""))
checksumFunction, err := instance.GetDigestFunction(digestFunctionEnum, len(expectedDigest))
if err != nil {
return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Failed to get checksum function for checksum.sri: %s", err.Error())
}
return expectedDigest, checksumFunction, nil
return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "No supported checksum in checksum.sri qualifier")
}
}

Expand Down
41 changes: 17 additions & 24 deletions pkg/fetch/http_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("UnknownChecksumSriAlgo", func(t *testing.T) {
t.Run("NoSupportedChecksum", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: InstanceName,
Uris: []string{uri, "www.another.com"},
Expand All @@ -255,42 +255,35 @@ func TestHTTPFetcherFetchBlob(t *testing.T) {
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Unsupported checksum algorithm sha0"), err)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "No supported checksum in checksum.sri qualifier"), err)
require.Nil(t, response)
})

t.Run("BadChecksumSriAlgo", func(t *testing.T) {
t.Run("MultipleChecksumsFirstUnsupported", func(t *testing.T) {
// First checksum uses unsupported algo, second is valid - should succeed
request := &remoteasset.FetchBlobRequest{
InstanceName: InstanceName,
Uris: []string{uri, "www.another.com"},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "no_dash",
},
},
}

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Bad checksum.sri hash expression: no_dash"), err)
require.Nil(t, response)
})

t.Run("BadChecksumSriBase64Value", func(t *testing.T) {
request := &remoteasset.FetchBlobRequest{
InstanceName: InstanceName,
Uris: []string{uri, "www.another.com"},
Uris: []string{uri},
Qualifiers: []*remoteasset.Qualifier{
{
Name: "checksum.sri",
Value: "sha256-no-base64",
Value: "sha0-invalid " + digestToChecksumSri(remoteexecution.DigestFunction_SHA256, helloDigest),
},
},
}
body := io.NopCloser(bytes.NewBuffer([]byte(TestData)))
httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{
Status: "200 Success",
StatusCode: 200,
Body: body,
ContentLength: 5,
}, nil)
casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall)

response, err := HTTPFetcher.FetchBlob(ctx, request)
testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Failed to decode checksum as base64 encoded sha256 sum: illegal base64 data at input byte 2"), err)
require.Nil(t, response)
require.NoError(t, err)
require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto()))
require.Equal(t, response.Status.Code, int32(codes.OK))
})

t.Run("OneFailOneSuccess", func(t *testing.T) {
Expand Down