Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,10 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string
}
return false
}) {
_, _, err = desktopClient.Pull(model, false, func(s string) {
printer := desktop.NewSimplePrinter(func(s string) {
_ = sendInfo(s)
})
_, _, err = desktopClient.Pull(model, false, printer)
if err != nil {
_ = sendErrorf("Failed to pull model: %v", err)
return fmt.Errorf("Failed to pull model: %v\n", err)
Expand Down
12 changes: 6 additions & 6 deletions cmd/cli/commands/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,9 @@ func TestIntegration_PushModel(t *testing.T) {

// Push the tagged model
t.Logf("Pushing model to custom registry with reference: %s", tc.ref)
_, _, err = env.client.Push(tc.ref, func(msg string) {
_, _, err = env.client.Push(tc.ref, desktop.NewSimplePrinter(func(msg string) {
t.Logf("Progress: %s", msg)
})
}))
require.NoError(t, err, "Failed to push model to custom registry")
t.Logf("✓ Successfully pushed model to custom registry: %s", tc.ref)
})
Expand Down Expand Up @@ -742,9 +742,9 @@ func TestIntegration_PushModel(t *testing.T) {

// Push the tagged model
t.Logf("Pushing model to custom registry with reference: %s", tc.targetRef)
_, _, err = env.client.Push(tc.targetRef, func(msg string) {
_, _, err = env.client.Push(tc.targetRef, desktop.NewSimplePrinter(func(msg string) {
t.Logf("Progress: %s", msg)
})
}))
require.NoError(t, err, "Failed to push model to custom registry")
t.Logf("✓ Successfully pushed model to custom registry: %s", tc.targetRef)
})
Expand All @@ -754,13 +754,13 @@ func TestIntegration_PushModel(t *testing.T) {
// Test 3: Error cases
t.Run("error cases", func(t *testing.T) {
t.Run("push non-existent model", func(t *testing.T) {
_, _, err := env.client.Push("non-existent-model:v1", func(msg string) {})
_, _, err := env.client.Push("non-existent-model:v1", desktop.NewSimplePrinter(func(msg string) {}))
require.Error(t, err, "Should fail when pushing non-existent model")
t.Logf("✓ Correctly failed to push non-existent model: %v", err)
})

t.Run("push with invalid reference", func(t *testing.T) {
_, _, err := env.client.Push("", func(msg string) {})
_, _, err := env.client.Push("", desktop.NewSimplePrinter(func(msg string) {}))
require.Error(t, err, "Should fail with empty reference")
t.Logf("✓ Correctly failed to push with invalid reference: %v", err)
})
Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
}

// Print progress messages
TUIProgress(progressMsg.Message)
fmt.Print("\r\033[K", progressMsg.Message)
}
cmd.PrintErrln("") // newline after progress

Expand Down
24 changes: 2 additions & 22 deletions cmd/cli/commands/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package commands

import (
"fmt"
"os"

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/cmd/cli/desktop"

"github.com/mattn/go-isatty"
"github.com/spf13/cobra"
)

Expand All @@ -33,18 +31,8 @@ func newPullCmd() *cobra.Command {
}

func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
var progress func(string)
if isatty.IsTerminal(os.Stdout.Fd()) {
progress = TUIProgress
} else {
progress = RawProgress
}
response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, progress)

// Add a newline before any output (success or error) if progress was shown.
if progressShown {
cmd.Println()
}
printer := asPrinter(cmd)
response, _, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, printer)

if err != nil {
return handleClientError(err, "Failed to pull model")
Expand All @@ -53,11 +41,3 @@ func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string,
cmd.Println(response)
return nil
}

func TUIProgress(message string) {
fmt.Print("\r\033[K", message)
}

func RawProgress(message string) {
fmt.Println(message)
}
3 changes: 2 additions & 1 deletion cmd/cli/commands/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ func newPushCmd() *cobra.Command {
}

func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
response, progressShown, err := desktopClient.Push(model, TUIProgress)
printer := asPrinter(cmd)
response, progressShown, err := desktopClient.Push(model, printer)

// Add a newline before any output (success or error) if progress was shown.
if progressShown {
Expand Down
87 changes: 13 additions & 74 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@ import (
"context"
"encoding/json"
"fmt"
"html"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/docker/go-units"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/inference"
dmrm "github.com/docker/model-runner/pkg/inference/models"
Expand Down Expand Up @@ -105,7 +104,7 @@ func (c *Client) Status() Status {
}
}

func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) {
model = normalizeHuggingFaceModelName(model)
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
if err != nil {
Expand All @@ -128,52 +127,16 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
}

progressShown := false
current := uint64(0) // Track cumulative progress across all layers
layerProgress := make(map[string]uint64) // Track progress per layer ID

scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
progressLine := scanner.Text()
if progressLine == "" {
continue
}

// Parse the progress message
var progressMsg ProgressMessage
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
}

// Handle different message types
switch progressMsg.Type {
case "progress":
// Update the current progress for this layer
layerID := progressMsg.Layer.ID
layerProgress[layerID] = progressMsg.Layer.Current

// Sum all layer progress values
current = uint64(0)
for _, layerCurrent := range layerProgress {
current += layerCurrent
}

progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"})))
progressShown = true
case "error":
return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message)
case "success":
return progressMsg.Message, progressShown, nil
default:
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
}
// Use Docker-style progress display
message, progressShown, err := DisplayProgress(resp.Body, printer)
if err != nil {
return "", progressShown, err
}

// If we get here, something went wrong
return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model)
return message, progressShown, nil
}

func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
model = normalizeHuggingFaceModelName(model)
pushPath := inference.ModelsPrefix + "/" + model + "/push"
resp, err := c.doRequest(
Expand All @@ -191,37 +154,13 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error)
return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body))
}

progressShown := false

scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
progressLine := scanner.Text()
if progressLine == "" {
continue
}

// Parse the progress message
var progressMsg ProgressMessage
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
return "", progressShown, fmt.Errorf("error parsing progress message: %w", err)
}

// Handle different message types
switch progressMsg.Type {
case "progress":
progress(progressMsg.Message)
progressShown = true
case "error":
return "", progressShown, fmt.Errorf("error pushing model: %s", progressMsg.Message)
case "success":
return progressMsg.Message, progressShown, nil
default:
return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type)
}
// Use Docker-style progress display
message, progressShown, err := DisplayProgress(resp.Body, printer)
if err != nil {
return "", progressShown, err
}

// If we get here, something went wrong
return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model)
return message, progressShown, nil
}

func (c *Client) List() ([]dmrm.Model, error) {
Expand Down
9 changes: 6 additions & 3 deletions cmd/cli/desktop/desktop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func TestPullHuggingFaceModel(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
}, nil)

_, _, err := client.Pull(modelName, false, func(s string) {})
printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Pull(modelName, false, printer)
assert.NoError(t, err)
}

Expand Down Expand Up @@ -122,7 +123,8 @@ func TestNonHuggingFaceModel(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
}, nil)

_, _, err := client.Pull(modelName, false, func(s string) {})
printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Pull(modelName, false, printer)
assert.NoError(t, err)
}

Expand All @@ -145,7 +147,8 @@ func TestPushHuggingFaceModel(t *testing.T) {
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)),
}, nil)

_, _, err := client.Push(modelName, func(s string) {})
printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Push(modelName, printer)
assert.NoError(t, err)
}

Expand Down
Loading
Loading