@@ -3,12 +3,16 @@ package rofl
33import (
44 "context"
55 "fmt"
6+ "io"
67 "maps"
78 "net/http"
89 "os"
910 "path/filepath"
1011 "slices"
1112 "strings"
13+ "sync"
14+ "sync/atomic"
15+ "time"
1216
1317 v1 "github.com/opencontainers/image-spec/specs-go/v1"
1418 oras "oras.land/oras-go/v2"
@@ -20,6 +24,8 @@ import (
2024
2125 "github.com/oasisprotocol/oasis-core/go/common/crypto/hash"
2226 "github.com/oasisprotocol/oasis-core/go/runtime/bundle"
27+
28+ "github.com/oasisprotocol/cli/cmd/common/progress"
2329)
2430
2531const (
@@ -31,6 +37,118 @@ const (
3137// DefaultOCIRegistry is the default OCI registry.
3238const DefaultOCIRegistry = "rofl.sh"
3339
40+ // progressBarUpdateInterval specifies how often the progress bar should be updated.
41+ const progressBarUpdateInterval = 1 * time .Second
42+
43+ // TargetWithProgress wraps oras.Target and provides updates via a progress bar.
44+ type TargetWithProgress struct {
45+ oras.Target
46+
47+ Message string
48+ UpdateInterval time.Duration
49+ BytesRead * atomic.Uint64
50+ BytesTotal uint64
51+
52+ stopUpdate chan struct {}
53+ wg sync.WaitGroup
54+ }
55+
56+ // NewTargetWithProgress creates a new TargetWithProgress.
57+ // bytesTotal is the size to use for the 100% value (use 0 if unknown).
58+ // msg is the message to display in front of the progress bar.
59+ func NewTargetWithProgress (target oras.Target , bytesTotal uint64 , msg string ) * TargetWithProgress {
60+ return & TargetWithProgress {
61+ Target : target ,
62+ Message : msg ,
63+ UpdateInterval : progressBarUpdateInterval ,
64+ BytesRead : & atomic.Uint64 {},
65+ BytesTotal : bytesTotal ,
66+ stopUpdate : make (chan struct {}),
67+ }
68+ }
69+
70+ // Push wraps the oras.Target Push method with progress bar updates.
71+ func (t * TargetWithProgress ) Push (ctx context.Context , desc v1.Descriptor , content io.Reader ) error {
72+ var progReader io.Reader
73+
74+ // Wrap the reader with our variant.
75+ pReader := & progressReader {
76+ reader : content ,
77+ bytesRead : t .BytesRead ,
78+ }
79+ progReader = pReader
80+
81+ // If the reader also has a WriteTo method, wrap it appropriately.
82+ if wt , ok := content .(io.WriterTo ); ok {
83+ progReader = & progressWriterToReader {
84+ progressReader : pReader ,
85+ writerTo : wt ,
86+ }
87+ }
88+
89+ // Do the actual push using our wrappers.
90+ return t .Target .Push (ctx , desc , progReader )
91+ }
92+
93+ // StartProgress starts updating the progress bar.
94+ func (t * TargetWithProgress ) StartProgress () {
95+ t .wg .Add (1 )
96+ go func () {
97+ defer t .wg .Done ()
98+
99+ ticker := time .NewTicker (t .UpdateInterval )
100+ defer ticker .Stop ()
101+
102+ for {
103+ select {
104+ case <- ticker .C :
105+ progress .PrintProgressBar (os .Stderr , t .Message , t .BytesRead .Load (), t .BytesTotal , false )
106+ case <- t .stopUpdate :
107+ return
108+ }
109+ }
110+ }()
111+ }
112+
113+ // StopProgress stops the progress bar updates.
114+ func (t * TargetWithProgress ) StopProgress () {
115+ if t .stopUpdate != nil {
116+ // Print the final stage of the progress bar.
117+ progress .PrintProgressBar (os .Stderr , t .Message , t .BytesRead .Load (), t .BytesTotal , true )
118+
119+ close (t .stopUpdate )
120+ t .stopUpdate = nil
121+ }
122+ t .wg .Wait ()
123+ }
124+
125+ type progressReader struct {
126+ reader io.Reader
127+ bytesRead * atomic.Uint64
128+ }
129+
130+ func (pr * progressReader ) Read (b []byte ) (int , error ) {
131+ n , err := pr .reader .Read (b )
132+ if n > 0 {
133+ pr .bytesRead .Add (uint64 (n ))
134+ }
135+ return n , err
136+ }
137+
138+ type progressWriterToReader struct {
139+ * progressReader
140+
141+ writerTo io.WriterTo
142+ }
143+
144+ func (pwr * progressWriterToReader ) WriteTo (w io.Writer ) (int64 , error ) {
145+ written , err := pwr .writerTo .WriteTo (w )
146+ if written > 0 {
147+ pwr .bytesRead .Add (uint64 (written ))
148+ }
149+ return written , err
150+ }
151+
34152// PushBundleToOciRepository pushes an ORC bundle to the given remote OCI repository.
35153//
36154// Returns the OCI manifest digest and the ORC manifest hash.
@@ -70,12 +188,26 @@ func PushBundleToOciRepository(bundleFn, dst string) (string, hash.Hash, error)
70188 return "" , hash.Hash {}, fmt .Errorf ("failed to explode bundle: %w" , err )
71189 }
72190
191+ // Keep track of the total size of the bundle as we add files to it.
192+ var totalSize int64
193+
194+ getFileSize := func (path string ) int64 {
195+ f , err := os .Stat (path )
196+ if err != nil {
197+ // This shouldn't happen, because store.Add should fail if the file doesn't exist.
198+ panic (err )
199+ }
200+ return f .Size ()
201+ }
202+
73203 // Generate the config object from the manifest.
74204 const manifestName = "META-INF/MANIFEST.MF"
75- configDsc , err := store .Add (ctx , manifestName , ociTypeOrcConfig , filepath .Join (bundleDir , manifestName ))
205+ manifestPath := filepath .Join (bundleDir , manifestName )
206+ configDsc , err := store .Add (ctx , manifestName , ociTypeOrcConfig , manifestPath )
76207 if err != nil {
77208 return "" , hash.Hash {}, fmt .Errorf ("failed to add config object from manifest: %w" , err )
78209 }
210+ totalSize += getFileSize (manifestPath )
79211
80212 // Add other files as layers.
81213 layers := make ([]v1.Descriptor , 0 , len (bnd .Data )- 1 )
@@ -86,10 +218,12 @@ func PushBundleToOciRepository(bundleFn, dst string) (string, hash.Hash, error)
86218 }
87219
88220 var layerDsc v1.Descriptor
89- layerDsc , err = store .Add (ctx , fn , ociTypeOrcLayer , filepath .Join (bundleDir , fn ))
221+ filePath := filepath .Join (bundleDir , fn )
222+ layerDsc , err = store .Add (ctx , fn , ociTypeOrcLayer , filePath )
90223 if err != nil {
91224 return "" , hash.Hash {}, fmt .Errorf ("failed to add OCI layer: %w" , err )
92225 }
226+ totalSize += getFileSize (filePath )
93227
94228 layers = append (layers , layerDsc )
95229 }
@@ -134,10 +268,20 @@ func PushBundleToOciRepository(bundleFn, dst string) (string, hash.Hash, error)
134268 }
135269 repo .Client = client
136270
271+ repoSize := uint64 (totalSize ) //nolint:gosec
272+ repoWithProgress := NewTargetWithProgress (repo , repoSize , "Pushing..." )
273+ repoWithProgress .StartProgress ()
274+ defer repoWithProgress .StopProgress ()
275+
137276 // Push to remote repository.
138- if _ , err = oras .Copy (ctx , store , tag , repo , tag , oras .DefaultCopyOptions ); err != nil {
277+ if _ , err = oras .Copy (ctx , store , tag , repoWithProgress , tag , oras .DefaultCopyOptions ); err != nil {
139278 return "" , hash.Hash {}, fmt .Errorf ("failed to push to remote OCI repository: %w" , err )
140279 }
141280
281+ // Force progress to 100% in case the ORC already exists on the remote.
282+ // This is necessary so we get a full progressbar instead of an empty one
283+ // when we're done.
284+ repoWithProgress .BytesRead .Store (repoWithProgress .BytesTotal )
285+
142286 return manifestDescriptor .Digest .String (), bnd .Manifest .Hash (), nil
143287}
0 commit comments