66import operator
77import os
88import socket
9+ import sys
910from dataclasses import dataclass , field
1011from functools import lru_cache , reduce
1112from math import ceil
4344 WorkloadStatusOperation ,
4445 WorkloadStatusStateEnum ,
4546)
46- from .__utils__ import safe_json
47+ from .__utils__ import _MiB , bytes_to_human_readable , safe_json
4748
4849if TYPE_CHECKING :
4950 from collections .abc import Callable , Generator
@@ -444,9 +445,9 @@ def _create_ephemeral_volumes(self, workload: DockerWorkloadPlan) -> dict[str, s
444445 return ephemeral_volume_name_mapping
445446
446447 def _pull_image (self , image : str ) -> docker .models .images .Image :
447- logger .info (f"Pulling image { image } " )
448-
449448 try :
449+ logger .info ("Pulling image %s" , image )
450+
450451 repo , tag = parse_repository_tag (image )
451452 tag = tag or "latest"
452453 auth_config = None
@@ -466,54 +467,9 @@ def _pull_image(self, image: str) -> docker.models.images.Image:
466467 decode = True ,
467468 auth_config = auth_config ,
468469 )
470+ _print_pull_logs (logs , image , tag )
469471
470- layers : dict [str , tqdm ] = {}
471-
472- def clean_layers ():
473- if not layers :
474- return
475- for layer in layers .values ():
476- layer .close ()
477- layers .clear ()
478-
479- for log in logs :
480- if "id" not in log :
481- clean_layers ()
482- logger .info (log ["status" ])
483- continue
484-
485- layer_id = log .get ("id" )
486- layer_status = log .get ("status" , "" )
487- layer_progress = log .get ("progressDetail" , {})
488- layer_progress_total = layer_progress .get ("total" , None )
489- layer_progress_current = layer_progress .get ("current" , None )
490- if layer_id not in layers :
491- layers [layer_id ] = tqdm (
492- unit = "B" ,
493- unit_scale = True ,
494- position = len (layers ),
495- ncols = 70 ,
496- desc = f"{ layer_id } : { layer_status } " ,
497- bar_format = "{desc}" ,
498- )
499- else :
500- layers [layer_id ].desc = f"{ layer_id } : { layer_status } "
501-
502- if layer_progress_total is not None :
503- layers [layer_id ].total = layer_progress_total
504- bf = "{desc} |{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]"
505- layers [layer_id ].bar_format = bf
506- elif layer_progress_current is not None :
507- layers [layer_id ].bar_format = "{desc} {n_fmt} [{rate_fmt}{postfix}]"
508- else :
509- layers [layer_id ].bar_format = "{desc}"
510-
511- if layer_progress_current :
512- layers [layer_id ].n = layer_progress_current
513-
514- layers [layer_id ].refresh ()
515-
516- clean_layers ()
472+ logger .info ("Pulled image %s" , image )
517473
518474 sep = "@" if tag .startswith ("sha256:" ) else ":"
519475 return self ._client .images .get (f"{ repo } { sep } { tag } " )
@@ -1959,3 +1915,150 @@ def _detail_api_call_error(err: docker.errors.APIError) -> str:
19591915 msg += f": status code { err .response .status_code } "
19601916
19611917 return msg
1918+
1919+
1920+ def _print_pull_logs (logs , image , tag ):
1921+ """
1922+ Display Docker image pull logs.
1923+
1924+ Args:
1925+ logs:
1926+ The logs from Docker image pull.
1927+ image:
1928+ The image being pulled.
1929+ tag:
1930+ The image tag being pulled.
1931+
1932+ """
1933+ if (
1934+ not envs .GPUSTACK_RUNTIME_DOCKER_IMAGE_NO_PULL_VISUALIZATION
1935+ and sys .stderr .isatty ()
1936+ ):
1937+ _visualize_pull_logs (logs , tag )
1938+ else :
1939+ _textualize_pull_logs (logs , image , tag )
1940+
1941+
1942+ def _visualize_pull_logs (logs , tag ):
1943+ """
1944+ Display Docker image pull logs as progress bars.
1945+
1946+ Args:
1947+ logs:
1948+ The logs from Docker image pull.
1949+ tag:
1950+ The image tag being pulled.
1951+
1952+ """
1953+ pbars : dict [str , tqdm ] = {}
1954+ dmsgs : list [str ] = []
1955+
1956+ try :
1957+ for log in logs :
1958+ id_ = log .get ("id" , None )
1959+ status = log .get ("status" , "" )
1960+ if not id_ :
1961+ dmsgs .append (status )
1962+ continue
1963+ if id_ == tag :
1964+ continue
1965+
1966+ progress = log .get ("progressDetail" , {})
1967+ progress_total = progress .get ("total" , None )
1968+ progress_current = progress .get ("current" , None )
1969+
1970+ if id_ not in pbars :
1971+ pbars [id_ ] = tqdm (
1972+ unit = "B" ,
1973+ unit_scale = True ,
1974+ desc = f"{ id_ } : { status } " ,
1975+ bar_format = "{desc}" ,
1976+ )
1977+ continue
1978+
1979+ pbars [id_ ].desc = f"{ id_ } : { status } "
1980+ if progress_total is not None :
1981+ pbars [id_ ].total = progress_total
1982+ bf = "{desc} |{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]"
1983+ pbars [id_ ].bar_format = bf
1984+ elif progress_current is not None :
1985+ pbars [id_ ].bar_format = "{desc} {n_fmt} [{rate_fmt}{postfix}]"
1986+ else :
1987+ pbars [id_ ].bar_format = "{desc}"
1988+
1989+ if progress_current :
1990+ pbars [id_ ].n = progress_current
1991+
1992+ pbars [id_ ].refresh ()
1993+ finally :
1994+ for pbar in pbars .values ():
1995+ pbar .close ()
1996+ pbars .clear ()
1997+
1998+ for msg in dmsgs :
1999+ print (msg , flush = True )
2000+
2001+
2002+ def _textualize_pull_logs (logs , image , tag ):
2003+ """
2004+ Display Docker image pull logs as plain text.
2005+
2006+ Args:
2007+ logs:
2008+ The logs from Docker image pull.
2009+ image:
2010+ The image being pulled.
2011+ tag:
2012+ The image tag being pulled.
2013+
2014+ """
2015+ pstats : dict [str , tuple [int , int ]] = {}
2016+ pstats_cursor : int = 0
2017+ pstats_cursor_move : int = 1
2018+ dmsgs : list [str ] = []
2019+
2020+ for log in logs :
2021+ id_ = log .get ("id" , None )
2022+ status = log .get ("status" , "" )
2023+ if not id_ :
2024+ dmsgs .append (status )
2025+ continue
2026+ if id_ == tag :
2027+ continue
2028+
2029+ if id_ not in pstats :
2030+ pstats [id_ ] = (0 , 0 )
2031+ continue
2032+
2033+ progress = log .get ("progressDetail" , {})
2034+ progress_total = progress .get ("total" , None )
2035+ progress_current = progress .get ("current" , None )
2036+
2037+ if progress_total is not None or progress_current is not None :
2038+ pstats [id_ ] = (progress_total or 0 , progress_current or 0 )
2039+
2040+ pstats_total , pstats_current = 0 , 0
2041+ for t , c in pstats .values ():
2042+ pstats_total += t
2043+ pstats_current += c
2044+
2045+ if pstats_total :
2046+ pstats_cursor_diff = int (
2047+ pstats_current * 100 // pstats_total - pstats_cursor ,
2048+ )
2049+ if pstats_cursor_diff >= pstats_cursor_move and pstats_cursor < 100 :
2050+ pstats_cursor += pstats_cursor_diff
2051+ pstats_cursor_move = min (5 , pstats_cursor_move + 1 )
2052+ print (f"Pulling image { image } : { pstats_cursor } %" , flush = True )
2053+ elif pstats_current :
2054+ pstats_cursor_diff = int (
2055+ pstats_current - pstats_cursor ,
2056+ )
2057+ if pstats_cursor_diff >= pstats_cursor_move :
2058+ pstats_cursor += pstats_cursor_diff
2059+ pstats_cursor_move = min (200 * _MiB , pstats_cursor_move + 2 * _MiB )
2060+ pstats_cursor_human = bytes_to_human_readable (pstats_cursor )
2061+ print (f"Pulling image { image } : { pstats_cursor_human } " , flush = True )
2062+
2063+ for msg in dmsgs :
2064+ print (msg , flush = True )
0 commit comments