diff --git a/pkg/fetch/http_fetcher.go b/pkg/fetch/http_fetcher.go index 8a4910d..93ffa38 100644 --- a/pkg/fetch/http_fetcher.go +++ b/pkg/fetch/http_fetcher.go @@ -173,7 +173,8 @@ func (hf *httpFetcher) downloadBlob(ctx context.Context, uri string, digestFunct return buffer.NewCASBufferFromByteSlice(digest, bodyBytes, buffer.UserProvided), digest } -// getChecksumSri parses the checksum.sri qualifier into an expected digest and a digest function to use +// getChecksumSri parses the checksum.sri qualifier into an expected digest and a digest function to use. +// Per the SRI spec, multiple checksums may be space-separated; the first supported one is used. func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, bb_digest.Function, error) { hashTypes := map[string]remoteexecution.DigestFunction_Value{ "sha256": remoteexecution.DigestFunction_SHA256, @@ -183,41 +184,35 @@ func getChecksumSri(qualifiers []*remoteasset.Qualifier) (string, bb_digest.Func "sha512": remoteexecution.DigestFunction_SHA512, "sha256tree": remoteexecution.DigestFunction_SHA256TREE, } - expectedDigest := "" - digestFunctionEnum := remoteexecution.DigestFunction_UNKNOWN for _, qualifier := range qualifiers { if qualifier.Name == "checksum.sri" { - if digestFunctionEnum != remoteexecution.DigestFunction_UNKNOWN { - return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Multiple checksum.sri provided") + for _, checksum := range strings.Fields(qualifier.Value) { + parts := strings.SplitN(checksum, "-", 2) + if len(parts) != 2 { + continue + } + hashName := parts[0] + b64hash := parts[1] + + digestFunctionEnum, ok := hashTypes[hashName] + if !ok { + continue + } + + decoded, err := base64.StdEncoding.DecodeString(b64hash) + if err != nil { + continue + } + expectedDigest := hex.EncodeToString(decoded) + + instance := util.Must(bb_digest.NewInstanceName("")) + checksumFunction, err := instance.GetDigestFunction(digestFunctionEnum, len(decoded)) + if err != nil { + continue + } + return expectedDigest, checksumFunction, nil } - parts := strings.SplitN(qualifier.Value, "-", 2) - if len(parts) != 2 { - return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Bad checksum.sri hash expression: %s", qualifier.Value) - } - hashName := parts[0] - b64hash := parts[1] - - digestFunctionEnum, ok := hashTypes[hashName] - if !ok { - return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Unsupported checksum algorithm %s", hashName) - } - - // Convert expected digest to hex - decoded, err := base64.StdEncoding.DecodeString(b64hash) - if err != nil { - return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Failed to decode checksum as base64 encoded %s sum: %s", hashName, err.Error()) - } - expectedDigest = hex.EncodeToString(decoded) - - // Convert to a proper digest function. - // Note: The Instance name doesn't matter here, this function is used only - // to give us a convenient API when actually checking the checksum. - instance := util.Must(bb_digest.NewInstanceName("")) - checksumFunction, err := instance.GetDigestFunction(digestFunctionEnum, len(expectedDigest)) - if err != nil { - return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "Failed to get checksum function for checksum.sri: %s", err.Error()) - } - return expectedDigest, checksumFunction, nil + return "", bb_digest.Function{}, status.Errorf(codes.InvalidArgument, "No supported checksum in checksum.sri qualifier") } } diff --git a/pkg/fetch/http_fetcher_test.go b/pkg/fetch/http_fetcher_test.go index a71c5bc..64daabe 100644 --- a/pkg/fetch/http_fetcher_test.go +++ b/pkg/fetch/http_fetcher_test.go @@ -242,7 +242,7 @@ func TestHTTPFetcherFetchBlob(t *testing.T) { require.Equal(t, response.Status.Code, int32(codes.OK)) }) - t.Run("UnknownChecksumSriAlgo", func(t *testing.T) { + t.Run("NoSupportedChecksum", func(t *testing.T) { request := &remoteasset.FetchBlobRequest{ InstanceName: InstanceName, Uris: []string{uri, "www.another.com"}, @@ -255,42 +255,35 @@ func TestHTTPFetcherFetchBlob(t *testing.T) { } response, err := HTTPFetcher.FetchBlob(ctx, request) - testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Unsupported checksum algorithm sha0"), err) + testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "No supported checksum in checksum.sri qualifier"), err) require.Nil(t, response) }) - t.Run("BadChecksumSriAlgo", func(t *testing.T) { + t.Run("MultipleChecksumsFirstUnsupported", func(t *testing.T) { + // First checksum uses unsupported algo, second is valid - should succeed request := &remoteasset.FetchBlobRequest{ InstanceName: InstanceName, - Uris: []string{uri, "www.another.com"}, - Qualifiers: []*remoteasset.Qualifier{ - { - Name: "checksum.sri", - Value: "no_dash", - }, - }, - } - - response, err := HTTPFetcher.FetchBlob(ctx, request) - testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Bad checksum.sri hash expression: no_dash"), err) - require.Nil(t, response) - }) - - t.Run("BadChecksumSriBase64Value", func(t *testing.T) { - request := &remoteasset.FetchBlobRequest{ - InstanceName: InstanceName, - Uris: []string{uri, "www.another.com"}, + Uris: []string{uri}, Qualifiers: []*remoteasset.Qualifier{ { Name: "checksum.sri", - Value: "sha256-no-base64", + Value: "sha0-invalid " + digestToChecksumSri(remoteexecution.DigestFunction_SHA256, helloDigest), }, }, } + body := io.NopCloser(bytes.NewBuffer([]byte(TestData))) + httpDoCall := roundTripper.EXPECT().RoundTrip(gomock.Any()).Return(&http.Response{ + Status: "200 Success", + StatusCode: 200, + Body: body, + ContentLength: 5, + }, nil) + casBlobAccess.EXPECT().Put(ctx, helloDigest, gomock.Any()).Return(nil).After(httpDoCall) response, err := HTTPFetcher.FetchBlob(ctx, request) - testutil.RequireEqualStatus(t, status.Error(codes.InvalidArgument, "Failed to decode checksum as base64 encoded sha256 sum: illegal base64 data at input byte 2"), err) - require.Nil(t, response) + require.NoError(t, err) + require.True(t, proto.Equal(response.BlobDigest, helloDigest.GetProto())) + require.Equal(t, response.Status.Code, int32(codes.OK)) }) t.Run("OneFailOneSuccess", func(t *testing.T) {