Skip to content

Commit 87106d8

Browse files
committed
Add progress bar
And remove existing transport implementation We are gonna implement this in the go-containerregistry layer it's cleaner Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 7fdb650 commit 87106d8

File tree

28 files changed

+293
-6469
lines changed

28 files changed

+293
-6469
lines changed

cmd/cli/commands/compose.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string
175175
}
176176
return false
177177
}) {
178-
_, _, err = desktopClient.Pull(model, false, func(s string) {
178+
printer := desktop.NewSimplePrinter(func(s string) {
179179
_ = sendInfo(s)
180180
})
181+
_, _, err = desktopClient.Pull(model, false, printer)
181182
if err != nil {
182183
_ = sendErrorf("Failed to pull model: %v", err)
183184
return fmt.Errorf("Failed to pull model: %v\n", err)

cmd/cli/commands/integration_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,9 @@ func TestIntegration_PushModel(t *testing.T) {
700700

701701
// Push the tagged model
702702
t.Logf("Pushing model to custom registry with reference: %s", tc.ref)
703-
_, _, err = env.client.Push(tc.ref, func(msg string) {
703+
_, _, err = env.client.Push(tc.ref, desktop.NewSimplePrinter(func(msg string) {
704704
t.Logf("Progress: %s", msg)
705-
})
705+
}))
706706
require.NoError(t, err, "Failed to push model to custom registry")
707707
t.Logf("✓ Successfully pushed model to custom registry: %s", tc.ref)
708708
})
@@ -742,9 +742,9 @@ func TestIntegration_PushModel(t *testing.T) {
742742

743743
// Push the tagged model
744744
t.Logf("Pushing model to custom registry with reference: %s", tc.targetRef)
745-
_, _, err = env.client.Push(tc.targetRef, func(msg string) {
745+
_, _, err = env.client.Push(tc.targetRef, desktop.NewSimplePrinter(func(msg string) {
746746
t.Logf("Progress: %s", msg)
747-
})
747+
}))
748748
require.NoError(t, err, "Failed to push model to custom registry")
749749
t.Logf("✓ Successfully pushed model to custom registry: %s", tc.targetRef)
750750
})
@@ -754,13 +754,13 @@ func TestIntegration_PushModel(t *testing.T) {
754754
// Test 3: Error cases
755755
t.Run("error cases", func(t *testing.T) {
756756
t.Run("push non-existent model", func(t *testing.T) {
757-
_, _, err := env.client.Push("non-existent-model:v1", func(msg string) {})
757+
_, _, err := env.client.Push("non-existent-model:v1", desktop.NewSimplePrinter(func(msg string) {}))
758758
require.Error(t, err, "Should fail when pushing non-existent model")
759759
t.Logf("✓ Correctly failed to push non-existent model: %v", err)
760760
})
761761

762762
t.Run("push with invalid reference", func(t *testing.T) {
763-
_, _, err := env.client.Push("", func(msg string) {})
763+
_, _, err := env.client.Push("", desktop.NewSimplePrinter(func(msg string) {}))
764764
require.Error(t, err, "Should fail with empty reference")
765765
t.Logf("✓ Correctly failed to push with invalid reference: %v", err)
766766
})

cmd/cli/commands/package.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
379379
}
380380

381381
// Print progress messages
382-
TUIProgress(progressMsg.Message)
382+
fmt.Print("\r\033[K", progressMsg.Message)
383383
}
384384
cmd.PrintErrln("") // newline after progress
385385

cmd/cli/commands/pull.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@ package commands
22

33
import (
44
"fmt"
5-
"os"
65

76
"github.com/docker/model-runner/cmd/cli/commands/completion"
87
"github.com/docker/model-runner/cmd/cli/desktop"
98

10-
"github.com/mattn/go-isatty"
119
"github.com/spf13/cobra"
1210
)
1311

@@ -42,13 +40,8 @@ func newPullCmd() *cobra.Command {
4240
}
4341

4442
func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
45-
var progress func(string)
46-
if isatty.IsTerminal(os.Stdout.Fd()) {
47-
progress = TUIProgress
48-
} else {
49-
progress = RawProgress
50-
}
51-
response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, progress)
43+
printer := asPrinter(cmd)
44+
response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, printer)
5245

5346
// Add a newline before any output (success or error) if progress was shown.
5447
if progressShown {
@@ -62,11 +55,3 @@ func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string,
6255
cmd.Println(response)
6356
return nil
6457
}
65-
66-
func TUIProgress(message string) {
67-
fmt.Print("\r\033[K", message)
68-
}
69-
70-
func RawProgress(message string) {
71-
fmt.Println(message)
72-
}

cmd/cli/commands/push.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ func newPushCmd() *cobra.Command {
3535
}
3636

3737
func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
38-
response, progressShown, err := desktopClient.Push(model, TUIProgress)
38+
printer := asPrinter(cmd)
39+
response, progressShown, err := desktopClient.Push(model, printer)
3940

4041
// Add a newline before any output (success or error) if progress was shown.
4142
if progressShown {

cmd/cli/desktop/desktop.go

Lines changed: 13 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@ import (
66
"context"
77
"encoding/json"
88
"fmt"
9-
"html"
109
"io"
1110
"net/http"
1211
"net/url"
1312
"strconv"
1413
"strings"
1514
"time"
1615

17-
"github.com/docker/go-units"
16+
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
1817
"github.com/docker/model-runner/pkg/distribution/distribution"
1918
"github.com/docker/model-runner/pkg/inference"
2019
dmrm "github.com/docker/model-runner/pkg/inference/models"
@@ -105,7 +104,7 @@ func (c *Client) Status() Status {
105104
}
106105
}
107106

108-
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
107+
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) {
109108
model = normalizeHuggingFaceModelName(model)
110109
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
111110
if err != nil {
@@ -128,52 +127,16 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
128127
return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
129128
}
130129

131-
progressShown := false
132-
current := uint64(0) // Track cumulative progress across all layers
133-
layerProgress := make(map[string]uint64) // Track progress per layer ID
134-
135-
scanner := bufio.NewScanner(resp.Body)
136-
for scanner.Scan() {
137-
progressLine := scanner.Text()
138-
if progressLine == "" {
139-
continue
140-
}
141-
142-
// Parse the progress message
143-
var progressMsg ProgressMessage
144-
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
145-
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
146-
}
147-
148-
// Handle different message types
149-
switch progressMsg.Type {
150-
case "progress":
151-
// Update the current progress for this layer
152-
layerID := progressMsg.Layer.ID
153-
layerProgress[layerID] = progressMsg.Layer.Current
154-
155-
// Sum all layer progress values
156-
current = uint64(0)
157-
for _, layerCurrent := range layerProgress {
158-
current += layerCurrent
159-
}
160-
161-
progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"})))
162-
progressShown = true
163-
case "error":
164-
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)
165-
case "success":
166-
return progressMsg.Message, progressShown, nil
167-
default:
168-
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
169-
}
130+
// Use Docker-style progress display
131+
message, err := DisplayProgress(resp.Body, printer)
132+
if err != nil {
133+
return "", true, err
170134
}
171135

172-
// If we get here, something went wrong
173-
return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model)
136+
return message, true, nil
174137
}
175138

176-
func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
139+
func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
177140
model = normalizeHuggingFaceModelName(model)
178141
pushPath := inference.ModelsPrefix + "/" + model + "/push"
179142
resp, err := c.doRequest(
@@ -191,37 +154,13 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error)
191154
return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
192155
}
193156

194-
progressShown := false
195-
196-
scanner := bufio.NewScanner(resp.Body)
197-
for scanner.Scan() {
198-
progressLine := scanner.Text()
199-
if progressLine == "" {
200-
continue
201-
}
202-
203-
// Parse the progress message
204-
var progressMsg ProgressMessage
205-
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
206-
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
207-
}
208-
209-
// Handle different message types
210-
switch progressMsg.Type {
211-
case "progress":
212-
progress(progressMsg.Message)
213-
progressShown = true
214-
case "error":
215-
return "", progressShown, fmt.Errorf("error pushing model: %s", progressMsg.Message)
216-
case "success":
217-
return progressMsg.Message, progressShown, nil
218-
default:
219-
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
220-
}
157+
// Use Docker-style progress display
158+
message, err := DisplayProgress(resp.Body, printer)
159+
if err != nil {
160+
return "", true, err
221161
}
222162

223-
// If we get here, something went wrong
224-
return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model)
163+
return message, true, nil
225164
}
226165

227166
func (c *Client) List() ([]dmrm.Model, error) {

cmd/cli/desktop/desktop_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ func TestPullHuggingFaceModel(t *testing.T) {
3636
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
3737
}, nil)
3838

39-
_, _, err := client.Pull(modelName, false, func(s string) {})
39+
printer := NewSimplePrinter(func(s string) {})
40+
_, _, err := client.Pull(modelName, false, printer)
4041
assert.NoError(t, err)
4142
}
4243

@@ -122,7 +123,8 @@ func TestNonHuggingFaceModel(t *testing.T) {
122123
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
123124
}, nil)
124125

125-
_, _, err := client.Pull(modelName, false, func(s string) {})
126+
printer := NewSimplePrinter(func(s string) {})
127+
_, _, err := client.Pull(modelName, false, printer)
126128
assert.NoError(t, err)
127129
}
128130

@@ -145,7 +147,8 @@ func TestPushHuggingFaceModel(t *testing.T) {
145147
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)),
146148
}, nil)
147149

148-
_, _, err := client.Push(modelName, func(s string) {})
150+
printer := NewSimplePrinter(func(s string) {})
151+
_, _, err := client.Push(modelName, printer)
149152
assert.NoError(t, err)
150153
}
151154

0 commit comments

Comments
 (0)