@@ -22,14 +22,15 @@ import (
2222 "fmt"
2323 "io"
2424
25+ retry "github.com/avast/retry-go/v4"
26+ sha256 "github.com/minio/sha256-simd"
27+ ocispec "github.com/opencontainers/image-spec/specs-go/v1"
28+ "golang.org/x/sync/errgroup"
29+
2530 internalpb "github.com/CloudNativeAI/modctl/internal/pb"
2631 "github.com/CloudNativeAI/modctl/pkg/backend/remote"
2732 "github.com/CloudNativeAI/modctl/pkg/config"
2833 "github.com/CloudNativeAI/modctl/pkg/storage"
29-
30- retry "github.com/avast/retry-go/v4"
31- ocispec "github.com/opencontainers/image-spec/specs-go/v1"
32- "golang.org/x/sync/errgroup"
3334)
3435
3536// Pull pulls an artifact from a registry.
@@ -136,6 +137,8 @@ func pullIfNotExist(ctx context.Context, pb *internalpb.ProgressBar, prompt stri
136137 defer content .Close ()
137138
138139 reader := pb .Add (prompt , desc .Digest .String (), desc .Size , content )
140+ hash := sha256 .New ()
141+ reader = io .TeeReader (reader , hash )
139142
140143 // push the content to the destination, and wrap the content reader for progress bar,
141144 // manifest should use dst.Manifests().Push, others should use dst.Blobs().Push.
@@ -185,6 +188,13 @@ func pullIfNotExist(ctx context.Context, pb *internalpb.ProgressBar, prompt stri
185188 }
186189 }
187190
191+ // validate the digest of the blob.
192+ if err := validateDigest (desc .Digest .String (), hash .Sum (nil )); err != nil {
193+ err = fmt .Errorf ("failed to validate the digest of the blob %s, err: %w" , desc .Digest .String (), err )
194+ pb .Abort (desc .Digest .String (), err )
195+ return err
196+ }
197+
188198 return nil
189199}
190200
@@ -199,11 +209,38 @@ func pullAndExtractFromRemote(ctx context.Context, pb *internalpb.ProgressBar, p
199209 defer content .Close ()
200210
201211 reader := pb .Add (prompt , desc .Digest .String (), desc .Size , content )
212+ hash := sha256 .New ()
213+ reader = io .TeeReader (reader , hash )
214+
202215 if err := extractLayer (desc , outputDir , reader ); err != nil {
203216 err = fmt .Errorf ("failed to extract the blob %s to output directory: %w" , desc .Digest .String (), err )
204217 pb .Abort (desc .Digest .String (), err )
205218 return err
206219 }
207220
221+ // validate the digest of the blob.
222+ if err := validateDigest (desc .Digest .String (), hash .Sum (nil )); err != nil {
223+ err = fmt .Errorf ("failed to validate the digest of the blob %s, err: %w" , desc .Digest .String (), err )
224+ pb .Abort (desc .Digest .String (), err )
225+ return err
226+ }
227+
228+ return nil
229+ }
230+
231+ // validateDigest validates the hash digest whether matches the expected digest.
232+ func validateDigest (digest string , hash []byte ) error {
233+ if digest == "" {
234+ return fmt .Errorf ("digest is empty" )
235+ }
236+
237+ if len (hash ) != sha256 .Size {
238+ return fmt .Errorf ("invalid hash length" )
239+ }
240+
241+ if digest != fmt .Sprintf ("sha256:%x" , hash ) {
242+ return fmt .Errorf ("actual digest %s does not match the expected digest %s" , fmt .Sprintf ("sha256:%x" , hash ), digest )
243+ }
244+
208245 return nil
209246}
0 commit comments