Skip to content

Commit 4784286

Browse files
committed
Add progress bar
And remove existing transport implementation We are gonna implement this in the go-containerregistry layer it's cleaner Signed-off-by: Eric Curtin <[email protected]>
1 parent 7fdb650 commit 4784286

File tree

27 files changed

+302
-6463
lines changed

27 files changed

+302
-6463
lines changed

cmd/cli/commands/compose.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string
175175
}
176176
return false
177177
}) {
178-
_, _, err = desktopClient.Pull(model, false, func(s string) {
178+
printer := desktop.NewSimplePrinter(func(s string) {
179179
_ = sendInfo(s)
180180
})
181+
_, _, err = desktopClient.Pull(model, false, printer)
181182
if err != nil {
182183
_ = sendErrorf("Failed to pull model: %v", err)
183184
return fmt.Errorf("Failed to pull model: %v\n", err)

cmd/cli/commands/package.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
379379
}
380380

381381
// Print progress messages
382-
TUIProgress(progressMsg.Message)
382+
fmt.Print("\r\033[K", progressMsg.Message)
383383
}
384384
cmd.PrintErrln("") // newline after progress
385385

cmd/cli/commands/pull.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ package commands
22

33
import (
44
"fmt"
5-
"os"
65

76
"github.com/docker/model-runner/cmd/cli/commands/completion"
87
"github.com/docker/model-runner/cmd/cli/desktop"
98

10-
"github.com/mattn/go-isatty"
119
"github.com/spf13/cobra"
1210
)
1311

@@ -42,13 +40,8 @@ func newPullCmd() *cobra.Command {
4240
}
4341

4442
func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
45-
var progress func(string)
46-
if isatty.IsTerminal(os.Stdout.Fd()) {
47-
progress = TUIProgress
48-
} else {
49-
progress = RawProgress
50-
}
51-
response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, progress)
43+
printer := desktop.NewCobraPrinter()
44+
response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, printer)
5245

5346
// Add a newline before any output (success or error) if progress was shown.
5447
if progressShown {
@@ -62,11 +55,3 @@ func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string,
6255
cmd.Println(response)
6356
return nil
6457
}
65-
66-
func TUIProgress(message string) {
67-
fmt.Print("\r\033[K", message)
68-
}
69-
70-
func RawProgress(message string) {
71-
fmt.Println(message)
72-
}

cmd/cli/commands/push.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ func newPushCmd() *cobra.Command {
3535
}
3636

3737
func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
38-
response, progressShown, err := desktopClient.Push(model, TUIProgress)
38+
printer := desktop.NewCobraPrinter()
39+
response, progressShown, err := desktopClient.Push(model, printer)
3940

4041
// Add a newline before any output (success or error) if progress was shown.
4142
if progressShown {

cmd/cli/desktop/desktop.go

Lines changed: 12 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@ import (
66
"context"
77
"encoding/json"
88
"fmt"
9-
"html"
109
"io"
1110
"net/http"
1211
"net/url"
1312
"strconv"
1413
"strings"
1514
"time"
1615

17-
"github.com/docker/go-units"
1816
"github.com/docker/model-runner/pkg/distribution/distribution"
1917
"github.com/docker/model-runner/pkg/inference"
2018
dmrm "github.com/docker/model-runner/pkg/inference/models"
@@ -105,7 +103,7 @@ func (c *Client) Status() Status {
105103
}
106104
}
107105

108-
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
106+
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer StatusPrinter) (string, bool, error) {
109107
model = normalizeHuggingFaceModelName(model)
110108
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
111109
if err != nil {
@@ -128,52 +126,16 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
128126
return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
129127
}
130128

131-
progressShown := false
132-
current := uint64(0) // Track cumulative progress across all layers
133-
layerProgress := make(map[string]uint64) // Track progress per layer ID
134-
135-
scanner := bufio.NewScanner(resp.Body)
136-
for scanner.Scan() {
137-
progressLine := scanner.Text()
138-
if progressLine == "" {
139-
continue
140-
}
141-
142-
// Parse the progress message
143-
var progressMsg ProgressMessage
144-
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
145-
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
146-
}
147-
148-
// Handle different message types
149-
switch progressMsg.Type {
150-
case "progress":
151-
// Update the current progress for this layer
152-
layerID := progressMsg.Layer.ID
153-
layerProgress[layerID] = progressMsg.Layer.Current
154-
155-
// Sum all layer progress values
156-
current = uint64(0)
157-
for _, layerCurrent := range layerProgress {
158-
current += layerCurrent
159-
}
160-
161-
progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"})))
162-
progressShown = true
163-
case "error":
164-
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)
165-
case "success":
166-
return progressMsg.Message, progressShown, nil
167-
default:
168-
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
169-
}
129+
// Use Docker-style progress display
130+
message, err := DisplayProgress(resp.Body, printer)
131+
if err != nil {
132+
return "", true, err
170133
}
171134

172-
// If we get here, something went wrong
173-
return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model)
135+
return message, true, nil
174136
}
175137

176-
func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
138+
func (c *Client) Push(model string, printer StatusPrinter) (string, bool, error) {
177139
model = normalizeHuggingFaceModelName(model)
178140
pushPath := inference.ModelsPrefix + "/" + model + "/push"
179141
resp, err := c.doRequest(
@@ -191,37 +153,13 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error)
191153
return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
192154
}
193155

194-
progressShown := false
195-
196-
scanner := bufio.NewScanner(resp.Body)
197-
for scanner.Scan() {
198-
progressLine := scanner.Text()
199-
if progressLine == "" {
200-
continue
201-
}
202-
203-
// Parse the progress message
204-
var progressMsg ProgressMessage
205-
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
206-
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
207-
}
208-
209-
// Handle different message types
210-
switch progressMsg.Type {
211-
case "progress":
212-
progress(progressMsg.Message)
213-
progressShown = true
214-
case "error":
215-
return "", progressShown, fmt.Errorf("error pushing model: %s", progressMsg.Message)
216-
case "success":
217-
return progressMsg.Message, progressShown, nil
218-
default:
219-
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
220-
}
156+
// Use Docker-style progress display
157+
message, err := DisplayProgress(resp.Body, printer)
158+
if err != nil {
159+
return "", true, err
221160
}
222161

223-
// If we get here, something went wrong
224-
return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model)
162+
return message, true, nil
225163
}
226164

227165
func (c *Client) List() ([]dmrm.Model, error) {

cmd/cli/desktop/desktop_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ func TestPullHuggingFaceModel(t *testing.T) {
3636
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
3737
}, nil)
3838

39-
_, _, err := client.Pull(modelName, false, func(s string) {})
39+
printer := NewSimplePrinter(func(s string) {})
40+
_, _, err := client.Pull(modelName, false, printer)
4041
assert.NoError(t, err)
4142
}
4243

@@ -122,7 +123,8 @@ func TestNonHuggingFaceModel(t *testing.T) {
122123
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
123124
}, nil)
124125

125-
_, _, err := client.Pull(modelName, false, func(s string) {})
126+
printer := NewSimplePrinter(func(s string) {})
127+
_, _, err := client.Pull(modelName, false, printer)
126128
assert.NoError(t, err)
127129
}
128130

@@ -145,7 +147,8 @@ func TestPushHuggingFaceModel(t *testing.T) {
145147
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)),
146148
}, nil)
147149

148-
_, _, err := client.Push(modelName, func(s string) {})
150+
printer := NewSimplePrinter(func(s string) {})
151+
_, _, err := client.Push(modelName, printer)
149152
assert.NoError(t, err)
150153
}
151154

0 commit comments

Comments
 (0)