@@ -6,15 +6,14 @@ import (
66 "context"
77 "encoding/json"
88 "fmt"
9- "html"
109 "io"
1110 "net/http"
1211 "net/url"
1312 "strconv"
1413 "strings"
1514 "time"
1615
17- "github.com/docker/go-units "
16+ "github.com/docker/model-runner/cmd/cli/pkg/standalone "
1817 "github.com/docker/model-runner/pkg/distribution/distribution"
1918 "github.com/docker/model-runner/pkg/inference"
2019 dmrm "github.com/docker/model-runner/pkg/inference/models"
@@ -105,7 +104,7 @@ func (c *Client) Status() Status {
105104 }
106105}
107106
108- func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , progress func ( string ) ) (string , bool , error ) {
107+ func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , printer standalone. StatusPrinter ) (string , bool , error ) {
109108 model = normalizeHuggingFaceModelName (model )
110109 jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
111110 if err != nil {
@@ -128,52 +127,16 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
128127 return "" , false , fmt .Errorf ("pulling %s failed with status %s: %s" , model , resp .Status , string (body ))
129128 }
130129
131- progressShown := false
132- current := uint64 (0 ) // Track cumulative progress across all layers
133- layerProgress := make (map [string ]uint64 ) // Track progress per layer ID
134-
135- scanner := bufio .NewScanner (resp .Body )
136- for scanner .Scan () {
137- progressLine := scanner .Text ()
138- if progressLine == "" {
139- continue
140- }
141-
142- // Parse the progress message
143- var progressMsg ProgressMessage
144- if err := json .Unmarshal ([]byte (html .UnescapeString (progressLine )), & progressMsg ); err != nil {
145- return "" , progressShown , fmt .Errorf ("error parsing progress message: %w" , err )
146- }
147-
148- // Handle different message types
149- switch progressMsg .Type {
150- case "progress" :
151- // Update the current progress for this layer
152- layerID := progressMsg .Layer .ID
153- layerProgress [layerID ] = progressMsg .Layer .Current
154-
155- // Sum all layer progress values
156- current = uint64 (0 )
157- for _ , layerCurrent := range layerProgress {
158- current += layerCurrent
159- }
160-
161- progress (fmt .Sprintf ("Downloaded %s of %s" , units .CustomSize ("%.2f%s" , float64 (current ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" }), units .CustomSize ("%.2f%s" , float64 (progressMsg .Total ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })))
162- progressShown = true
163- case "error" :
164- return "" , progressShown , fmt .Errorf ("error pulling model: %s" , progressMsg .Message )
165- case "success" :
166- return progressMsg .Message , progressShown , nil
167- default :
168- return "" , progressShown , fmt .Errorf ("unknown message type: %s" , progressMsg .Type )
169- }
130+ // Use Docker-style progress display
131+ message , err := DisplayProgress (resp .Body , printer )
132+ if err != nil {
133+ return "" , true , err
170134 }
171135
172- // If we get here, something went wrong
173- return "" , progressShown , fmt .Errorf ("unexpected end of stream while pulling model %s" , model )
136+ return message , true , nil
174137}
175138
176- func (c * Client ) Push (model string , progress func ( string ) ) (string , bool , error ) {
139+ func (c * Client ) Push (model string , printer standalone. StatusPrinter ) (string , bool , error ) {
177140 model = normalizeHuggingFaceModelName (model )
178141 pushPath := inference .ModelsPrefix + "/" + model + "/push"
179142 resp , err := c .doRequest (
@@ -191,37 +154,13 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error)
191154 return "" , false , fmt .Errorf ("pushing %s failed with status %s: %s" , model , resp .Status , string (body ))
192155 }
193156
194- progressShown := false
195-
196- scanner := bufio .NewScanner (resp .Body )
197- for scanner .Scan () {
198- progressLine := scanner .Text ()
199- if progressLine == "" {
200- continue
201- }
202-
203- // Parse the progress message
204- var progressMsg ProgressMessage
205- if err := json .Unmarshal ([]byte (html .UnescapeString (progressLine )), & progressMsg ); err != nil {
206- return "" , progressShown , fmt .Errorf ("error parsing progress message: %w" , err )
207- }
208-
209- // Handle different message types
210- switch progressMsg .Type {
211- case "progress" :
212- progress (progressMsg .Message )
213- progressShown = true
214- case "error" :
215- return "" , progressShown , fmt .Errorf ("error pushing model: %s" , progressMsg .Message )
216- case "success" :
217- return progressMsg .Message , progressShown , nil
218- default :
219- return "" , progressShown , fmt .Errorf ("unknown message type: %s" , progressMsg .Type )
220- }
157+ // Use Docker-style progress display
158+ message , err := DisplayProgress (resp .Body , printer )
159+ if err != nil {
160+ return "" , true , err
221161 }
222162
223- // If we get here, something went wrong
224- return "" , progressShown , fmt .Errorf ("unexpected end of stream while pushing model %s" , model )
163+ return message , true , nil
225164}
226165
227166func (c * Client ) List () ([]dmrm.Model , error ) {
0 commit comments