Skip to content

Commit d88624b

Browse files
chlins authored and BraveY committed
feat: validate the sha256 after pulling the blob
Signed-off-by: chlins <[email protected]>
1 parent c7f486c commit d88624b

File tree

4 files changed

+67
-16
lines changed

4 files changed

+67
-16
lines changed

pkg/backend/fetch_test.go

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,15 @@ package backend
1919
import (
2020
"context"
2121
"encoding/json"
22+
"fmt"
2223
"net/http"
2324
"net/http/httptest"
2425
"os"
2526
"strings"
2627
"testing"
2728

2829
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
30+
godigest "github.com/opencontainers/go-digest"
2931
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
3032
"github.com/stretchr/testify/assert"
3133
"github.com/stretchr/testify/require"
@@ -39,6 +41,15 @@ func TestFetch(t *testing.T) {
3941
require.NoError(t, err)
4042
defer os.RemoveAll(tempDir)
4143

44+
// Setup mock file
45+
const (
46+
file1Content = "file1 content..."
47+
file2Content = "file2 content..."
48+
)
49+
50+
file1Digest := godigest.FromString(file1Content)
51+
file2Digest := godigest.FromString(file2Content)
52+
4253
// Setup mock server
4354
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4455
switch r.URL.Path {
@@ -51,16 +62,16 @@ func TestFetch(t *testing.T) {
5162
Layers: []ocispec.Descriptor{
5263
{
5364
MediaType: "application/octet-stream.raw",
54-
Digest: "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
55-
Size: 0,
65+
Digest: file1Digest,
66+
Size: int64(len(file1Content)),
5667
Annotations: map[string]string{
5768
modelspec.AnnotationFilepath: "file1.txt",
5869
},
5970
},
6071
{
6172
MediaType: "application/octet-stream.raw",
62-
Digest: "sha256:a3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
63-
Size: 0,
73+
Digest: file2Digest,
74+
Size: int64(len(file2Content)),
6475
Annotations: map[string]string{
6576
modelspec.AnnotationFilepath: "file2.txt",
6677
},
@@ -69,10 +80,13 @@ func TestFetch(t *testing.T) {
6980
}
7081
w.Header().Set("Content-Type", "application/json")
7182
require.NoError(t, json.NewEncoder(w).Encode(manifest))
72-
case "/v2/test/model/blobs/sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
73-
"/v2/test/model/blobs/sha256:a3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
74-
// Return empty content for blobs
75-
w.WriteHeader(http.StatusOK)
83+
case fmt.Sprintf("/v2/test/model/blobs/%s", file1Digest):
84+
_, err := w.Write([]byte(file1Content))
85+
require.NoError(t, err)
86+
87+
case fmt.Sprintf("/v2/test/model/blobs/%s", file2Digest):
88+
_, err := w.Write([]byte(file2Content))
89+
require.NoError(t, err)
7690
default:
7791
t.Logf("Unexpected request to %s", r.URL.Path)
7892
w.WriteHeader(http.StatusNotFound)

pkg/backend/processor/options.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,6 @@ func WithProgressTracker(tracker *pb.ProgressBar) ProcessOption {
4848
var retryOpts = []retry.Option{
4949
retry.Attempts(3),
5050
retry.DelayType(retry.BackOffDelay),
51-
retry.Delay(1 * time.Second),
52-
retry.MaxDelay(5 * time.Second),
51+
retry.Delay(5 * time.Second),
52+
retry.MaxDelay(10 * time.Second),
5353
}

pkg/backend/pull.go

Lines changed: 41 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,14 +22,15 @@ import (
2222
"fmt"
2323
"io"
2424

25+
retry "github.com/avast/retry-go/v4"
26+
sha256 "github.com/minio/sha256-simd"
27+
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
28+
"golang.org/x/sync/errgroup"
29+
2530
internalpb "github.com/CloudNativeAI/modctl/internal/pb"
2631
"github.com/CloudNativeAI/modctl/pkg/backend/remote"
2732
"github.com/CloudNativeAI/modctl/pkg/config"
2833
"github.com/CloudNativeAI/modctl/pkg/storage"
29-
30-
retry "github.com/avast/retry-go/v4"
31-
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
32-
"golang.org/x/sync/errgroup"
3334
)
3435

3536
// Pull pulls an artifact from a registry.
@@ -136,6 +137,8 @@ func pullIfNotExist(ctx context.Context, pb *internalpb.ProgressBar, prompt stri
136137
defer content.Close()
137138

138139
reader := pb.Add(prompt, desc.Digest.String(), desc.Size, content)
140+
hash := sha256.New()
141+
reader = io.TeeReader(reader, hash)
139142

140143
// push the content to the destination, and wrap the content reader for progress bar,
141144
// manifest should use dst.Manifests().Push, others should use dst.Blobs().Push.
@@ -185,6 +188,13 @@ func pullIfNotExist(ctx context.Context, pb *internalpb.ProgressBar, prompt stri
185188
}
186189
}
187190

191+
// validate the digest of the blob.
192+
if err := validateDigest(desc.Digest.String(), hash.Sum(nil)); err != nil {
193+
err = fmt.Errorf("failed to validate the digest of the blob %s, err: %w", desc.Digest.String(), err)
194+
pb.Abort(desc.Digest.String(), err)
195+
return err
196+
}
197+
188198
return nil
189199
}
190200

@@ -199,11 +209,38 @@ func pullAndExtractFromRemote(ctx context.Context, pb *internalpb.ProgressBar, p
199209
defer content.Close()
200210

201211
reader := pb.Add(prompt, desc.Digest.String(), desc.Size, content)
212+
hash := sha256.New()
213+
reader = io.TeeReader(reader, hash)
214+
202215
if err := extractLayer(desc, outputDir, reader); err != nil {
203216
err = fmt.Errorf("failed to extract the blob %s to output directory: %w", desc.Digest.String(), err)
204217
pb.Abort(desc.Digest.String(), err)
205218
return err
206219
}
207220

221+
// validate the digest of the blob.
222+
if err := validateDigest(desc.Digest.String(), hash.Sum(nil)); err != nil {
223+
err = fmt.Errorf("failed to validate the digest of the blob %s, err: %w", desc.Digest.String(), err)
224+
pb.Abort(desc.Digest.String(), err)
225+
return err
226+
}
227+
228+
return nil
229+
}
230+
231+
// validateDigest checks whether the computed hash matches the expected digest.
232+
func validateDigest(digest string, hash []byte) error {
233+
if digest == "" {
234+
return fmt.Errorf("digest is empty")
235+
}
236+
237+
if len(hash) != sha256.Size {
238+
return fmt.Errorf("invalid hash length")
239+
}
240+
241+
if digest != fmt.Sprintf("sha256:%x", hash) {
242+
return fmt.Errorf("actual digest %s does not match the expected digest %s", fmt.Sprintf("sha256:%x", hash), digest)
243+
}
244+
208245
return nil
209246
}

pkg/backend/retry.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,6 @@ import (
2525
var retryOpts = []retry.Option{
2626
retry.Attempts(3),
2727
retry.DelayType(retry.BackOffDelay),
28-
retry.Delay(1 * time.Second),
29-
retry.MaxDelay(5 * time.Second),
28+
retry.Delay(5 * time.Second),
29+
retry.MaxDelay(10 * time.Second),
3030
}

0 commit comments

Comments
 (0)