@@ -6,15 +6,13 @@ import (
66 "context"
77 "encoding/json"
88 "fmt"
9- "html"
109 "io"
1110 "net/http"
1211 "net/url"
1312 "strconv"
1413 "strings"
1514 "time"
1615
17- "github.com/docker/go-units"
1816 "github.com/docker/model-runner/pkg/distribution/distribution"
1917 "github.com/docker/model-runner/pkg/inference"
2018 dmrm "github.com/docker/model-runner/pkg/inference/models"
@@ -105,7 +103,7 @@ func (c *Client) Status() Status {
105103 }
106104}
107105
108- func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , progress func ( string ) ) (string , bool , error ) {
106+ func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , printer StatusPrinter ) (string , bool , error ) {
109107 model = normalizeHuggingFaceModelName (model )
110108 jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
111109 if err != nil {
@@ -128,52 +126,16 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
128126 return "" , false , fmt .Errorf ("pulling %s failed with status %s: %s" , model , resp .Status , string (body ))
129127 }
130128
131- progressShown := false
132- current := uint64 (0 ) // Track cumulative progress across all layers
133- layerProgress := make (map [string ]uint64 ) // Track progress per layer ID
134-
135- scanner := bufio .NewScanner (resp .Body )
136- for scanner .Scan () {
137- progressLine := scanner .Text ()
138- if progressLine == "" {
139- continue
140- }
141-
142- // Parse the progress message
143- var progressMsg ProgressMessage
144- if err := json .Unmarshal ([]byte (html .UnescapeString (progressLine )), & progressMsg ); err != nil {
145- return "" , progressShown , fmt .Errorf ("error parsing progress message: %w" , err )
146- }
147-
148- // Handle different message types
149- switch progressMsg .Type {
150- case "progress" :
151- // Update the current progress for this layer
152- layerID := progressMsg .Layer .ID
153- layerProgress [layerID ] = progressMsg .Layer .Current
154-
155- // Sum all layer progress values
156- current = uint64 (0 )
157- for _ , layerCurrent := range layerProgress {
158- current += layerCurrent
159- }
160-
161- progress (fmt .Sprintf ("Downloaded %s of %s" , units .CustomSize ("%.2f%s" , float64 (current ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" }), units .CustomSize ("%.2f%s" , float64 (progressMsg .Total ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })))
162- progressShown = true
163- case "error" :
164- return "" , progressShown , fmt .Errorf ("error pulling model: %s" , progressMsg .Message )
165- case "success" :
166- return progressMsg .Message , progressShown , nil
167- default :
168- return "" , progressShown , fmt .Errorf ("unknown message type: %s" , progressMsg .Type )
169- }
129+ // Use Docker-style progress display
130+ message , err := DisplayProgress (resp .Body , printer )
131+ if err != nil {
132+ return "" , true , err
170133 }
171134
172- // If we get here, something went wrong
173- return "" , progressShown , fmt .Errorf ("unexpected end of stream while pulling model %s" , model )
135+ return message , true , nil
174136}
175137
176- func (c * Client ) Push (model string , progress func ( string ) ) (string , bool , error ) {
138+ func (c * Client ) Push (model string , printer StatusPrinter ) (string , bool , error ) {
177139 model = normalizeHuggingFaceModelName (model )
178140 pushPath := inference .ModelsPrefix + "/" + model + "/push"
179141 resp , err := c .doRequest (
@@ -191,37 +153,13 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error)
191153 return "" , false , fmt .Errorf ("pushing %s failed with status %s: %s" , model , resp .Status , string (body ))
192154 }
193155
194- progressShown := false
195-
196- scanner := bufio .NewScanner (resp .Body )
197- for scanner .Scan () {
198- progressLine := scanner .Text ()
199- if progressLine == "" {
200- continue
201- }
202-
203- // Parse the progress message
204- var progressMsg ProgressMessage
205- if err := json .Unmarshal ([]byte (html .UnescapeString (progressLine )), & progressMsg ); err != nil {
206- return "" , progressShown , fmt .Errorf ("error parsing progress message: %w" , err )
207- }
208-
209- // Handle different message types
210- switch progressMsg .Type {
211- case "progress" :
212- progress (progressMsg .Message )
213- progressShown = true
214- case "error" :
215- return "" , progressShown , fmt .Errorf ("error pushing model: %s" , progressMsg .Message )
216- case "success" :
217- return progressMsg .Message , progressShown , nil
218- default :
219- return "" , progressShown , fmt .Errorf ("unknown message type: %s" , progressMsg .Type )
220- }
156+ // Use Docker-style progress display
157+ message , err := DisplayProgress (resp .Body , printer )
158+ if err != nil {
159+ return "" , true , err
221160 }
222161
223- // If we get here, something went wrong
224- return "" , progressShown , fmt .Errorf ("unexpected end of stream while pushing model %s" , model )
162+ return message , true , nil
225163}
226164
227165func (c * Client ) List () ([]dmrm.Model , error ) {
0 commit comments