@@ -9,9 +9,13 @@ import (
99 "html"
1010 "io"
1111 "net/http"
12+ "os"
13+ "runtime"
1214 "strconv"
1315 "strings"
16+ "syscall"
1417 "time"
18+ "unsafe"
1519
1620 "github.com/docker/go-units"
1721 "github.com/docker/model-distribution/distribution"
@@ -106,6 +110,147 @@ func (c *Client) Status() Status {
106110 }
107111}
108112
113+ func humanReadableSize (size float64 ) string {
114+ return units .CustomSize ("%.2f%s" , float64 (size ), 1000.0 , []string {"B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })
115+ }
116+
117+ func humanReadableSizePad (size float64 , width int ) string {
118+ return fmt .Sprintf ("%*s" , width , humanReadableSize (size ))
119+ }
120+
121+ func humanReadableTimePad (seconds int64 , width int ) string {
122+ var s string
123+ if seconds < 60 {
124+ s = fmt .Sprintf ("%ds" , seconds )
125+ } else if seconds < 3600 {
126+ s = fmt .Sprintf ("%dm %02ds" , seconds / 60 , seconds % 60 )
127+ } else {
128+ s = fmt .Sprintf ("%dh %02dm %02ds" , seconds / 3600 , (seconds % 3600 )/ 60 , seconds % 60 )
129+ }
130+ return fmt .Sprintf ("%*s" , width , s )
131+ }
132+
133+ // ProgressBarState tracks the running totals and timing for speed/ETA
134+ type ProgressBarState struct {
135+ LastDownloaded uint64
136+ LastTime time.Time
137+ StartTime time.Time
138+ UpdateInterval time.Duration // New: interval between updates
139+ lastPrint time.Time // New: last time the progress bar was printed
140+ }
141+
142+ // formatBar calculates the bar width and filled bar string.
143+ func (pbs * ProgressBarState ) formatBar (percent float64 , termWidth int , prefix , suffix string ) string {
144+ barWidth := termWidth - len (prefix ) - len (suffix ) - 4
145+ if barWidth < 10 {
146+ barWidth = 10
147+ }
148+ filled := int (percent / 100 * float64 (barWidth ))
149+ if filled > barWidth {
150+ filled = barWidth
151+ }
152+ bar := strings .Repeat ("█" , filled ) + strings .Repeat (" " , barWidth - filled )
153+ return bar
154+ }
155+
156+ // calcSpeed calculates the current download speed.
157+ func (pbs * ProgressBarState ) calcSpeed (current uint64 , now time.Time ) float64 {
158+ elapsed := now .Sub (pbs .LastTime ).Seconds ()
159+ if elapsed <= 0 {
160+ return 0
161+ }
162+
163+ speed := float64 (current - pbs .LastDownloaded ) / elapsed
164+ pbs .LastTime = now
165+ pbs .LastDownloaded = current
166+
167+ return speed
168+ }
169+
170+ // formatSuffix returns the suffix string showing human readable sizes, speed, and ETA.
171+ func (pbs * ProgressBarState ) fmtSuffix (current , total uint64 , speed float64 , eta int64 ) string {
172+ return fmt .Sprintf ("%s/%s %s/s %s" ,
173+ humanReadableSizePad (float64 (current ), 10 ),
174+ humanReadableSize (float64 (total )),
175+ humanReadableSizePad (speed , 10 ),
176+ humanReadableTimePad (eta , 16 ),
177+ )
178+ }
179+
180+ // calcETA calculates the estimated time remaining.
181+ func (pbs * ProgressBarState ) calcETA (current , total uint64 , speed float64 ) int64 {
182+ if speed <= 0 {
183+ return 0
184+ }
185+ return int64 (float64 (total - current ) / speed )
186+ }
187+
188+ // printProgressBar prints/updates a progress bar in the terminal
189+ // Only prints if UpdateInterval has passed since last print, or always if interval=0
190+ func (pbs * ProgressBarState ) printProgressBar (current , total uint64 ) {
191+ if pbs .StartTime .IsZero () {
192+ pbs .StartTime = time .Now ()
193+ pbs .LastTime = pbs .StartTime
194+ pbs .LastDownloaded = current
195+ pbs .lastPrint = pbs .StartTime
196+ }
197+
198+ now := time .Now ()
199+ // Only update display if enough time passed,
200+ // unless interval is 0 (always print)
201+ if pbs .UpdateInterval > 0 && now .Sub (pbs .lastPrint ) < pbs .UpdateInterval && current != total {
202+ return
203+ }
204+
205+ pbs .lastPrint = now
206+ termWidth := getTerminalWidth ()
207+ percent := float64 (current ) / float64 (total ) * 100
208+ prefix := fmt .Sprintf ("%3.0f%% |" , percent )
209+ speed := pbs .calcSpeed (current , now )
210+ eta := pbs .calcETA (current , total , speed )
211+ suffix := pbs .fmtSuffix (current , total , speed , eta )
212+ bar := pbs .formatBar (percent , termWidth , prefix , suffix )
213+ fmt .Fprintf (os .Stderr , "\r %s%s| %s" , prefix , bar , suffix )
214+ }
215+
216+ func getTerminalWidthUnix () (int , error ) {
217+ type winsize struct {
218+ Row uint16
219+ Col uint16
220+ Xpixel uint16
221+ Ypixel uint16
222+ }
223+ ws := & winsize {}
224+ retCode , _ , errno := syscall .Syscall6 (
225+ syscall .SYS_IOCTL ,
226+ uintptr (os .Stdout .Fd ()),
227+ uintptr (syscall .TIOCGWINSZ ),
228+ uintptr (unsafe .Pointer (ws )),
229+ 0 , 0 , 0 ,
230+ )
231+ if int (retCode ) == - 1 {
232+ return 0 , errno
233+ }
234+ return int (ws .Col ), nil
235+ }
236+
237+ // getTerminalSize tries to get the terminal width (default 80 if fails)
238+ func getTerminalWidth () int {
239+ var width int
240+ var err error
241+ default_width := 80
242+ if runtime .GOOS == "windows" { // to be implemented
243+ return default_width
244+ }
245+
246+ width , err = getTerminalWidthUnix ()
247+ if width == 0 || err != nil {
248+ return default_width
249+ }
250+
251+ return width
252+ }
253+
109254func (c * Client ) Pull (model string , ignoreRuntimeMemoryCheck bool , progress func (string )) (string , bool , error ) {
110255 model = normalizeHuggingFaceModelName (model )
111256 jsonData , err := json .Marshal (dmrm.ModelCreateRequest {From : model , IgnoreRuntimeMemoryCheck : ignoreRuntimeMemoryCheck })
@@ -130,10 +275,14 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
130275 }
131276
132277 progressShown := false
133- current := uint64 (0 ) // Track cumulative progress across all layers
278+ // Track cumulative progress across all layers
279+ current := uint64 (0 )
134280 layerProgress := make (map [string ]uint64 ) // Track progress per layer ID
135281
136282 scanner := bufio .NewScanner (resp .Body )
283+ pbs := & ProgressBarState {
284+ UpdateInterval : time .Millisecond * 100 ,
285+ }
137286 for scanner .Scan () {
138287 progressLine := scanner .Text ()
139288 if progressLine == "" {
@@ -159,7 +308,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
159308 current += layerCurrent
160309 }
161310
162- progress ( fmt . Sprintf ( "Downloaded %s of %s" , units . CustomSize ( "%.2f%s" , float64 ( current ), 1000.0 , [] string { "B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" }), units . CustomSize ( "%.2f%s" , float64 ( progressMsg .Total ), 1000.0 , [] string { "B" , "kB" , "MB" , "GB" , "TB" , "PB" , "EB" , "ZB" , "YB" })) )
311+ pbs . printProgressBar ( current , progressMsg .Total )
163312 progressShown = true
164313 case "error" :
165314 return "" , progressShown , fmt .Errorf ("error pulling model: %s" , progressMsg .Message )
0 commit comments