Skip to content

Commit a9796bc

Browse files
committed
refactor: adjust pull progress
Signed-off-by: thxCode <[email protected]>
1 parent fb4c683 commit a9796bc

File tree

2 files changed

+171
-61
lines changed

2 files changed

+171
-61
lines changed

gpustack_runtime/deployer/docker.py

Lines changed: 153 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import operator
77
import os
88
import socket
9+
import sys
910
from dataclasses import dataclass, field
1011
from functools import lru_cache, reduce
1112
from math import ceil
@@ -43,7 +44,7 @@
4344
WorkloadStatusOperation,
4445
WorkloadStatusStateEnum,
4546
)
46-
from .__utils__ import safe_json
47+
from .__utils__ import _MiB, bytes_to_human_readable, safe_json
4748

4849
if TYPE_CHECKING:
4950
from collections.abc import Callable, Generator
@@ -444,9 +445,9 @@ def _create_ephemeral_volumes(self, workload: DockerWorkloadPlan) -> dict[str, s
444445
return ephemeral_volume_name_mapping
445446

446447
def _pull_image(self, image: str) -> docker.models.images.Image:
447-
logger.info(f"Pulling image {image}")
448-
449448
try:
449+
logger.info("Pulling image %s", image)
450+
450451
repo, tag = parse_repository_tag(image)
451452
tag = tag or "latest"
452453
auth_config = None
@@ -466,54 +467,9 @@ def _pull_image(self, image: str) -> docker.models.images.Image:
466467
decode=True,
467468
auth_config=auth_config,
468469
)
470+
_print_pull_logs(logs, image, tag)
469471

470-
layers: dict[str, tqdm] = {}
471-
472-
def clean_layers():
473-
if not layers:
474-
return
475-
for layer in layers.values():
476-
layer.close()
477-
layers.clear()
478-
479-
for log in logs:
480-
if "id" not in log:
481-
clean_layers()
482-
logger.info(log["status"])
483-
continue
484-
485-
layer_id = log.get("id")
486-
layer_status = log.get("status", "")
487-
layer_progress = log.get("progressDetail", {})
488-
layer_progress_total = layer_progress.get("total", None)
489-
layer_progress_current = layer_progress.get("current", None)
490-
if layer_id not in layers:
491-
layers[layer_id] = tqdm(
492-
unit="B",
493-
unit_scale=True,
494-
position=len(layers),
495-
ncols=70,
496-
desc=f"{layer_id}: {layer_status}",
497-
bar_format="{desc}",
498-
)
499-
else:
500-
layers[layer_id].desc = f"{layer_id}: {layer_status}"
501-
502-
if layer_progress_total is not None:
503-
layers[layer_id].total = layer_progress_total
504-
bf = "{desc} |{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]"
505-
layers[layer_id].bar_format = bf
506-
elif layer_progress_current is not None:
507-
layers[layer_id].bar_format = "{desc} {n_fmt} [{rate_fmt}{postfix}]"
508-
else:
509-
layers[layer_id].bar_format = "{desc}"
510-
511-
if layer_progress_current:
512-
layers[layer_id].n = layer_progress_current
513-
514-
layers[layer_id].refresh()
515-
516-
clean_layers()
472+
logger.info("Pulled image %s", image)
517473

518474
sep = "@" if tag.startswith("sha256:") else ":"
519475
return self._client.images.get(f"{repo}{sep}{tag}")
@@ -1959,3 +1915,150 @@ def _detail_api_call_error(err: docker.errors.APIError) -> str:
19591915
msg += f": status code {err.response.status_code}"
19601916

19611917
return msg
1918+
1919+
1920+
def _print_pull_logs(logs, image, tag):
1921+
"""
1922+
Display Docker image pull logs.
1923+
1924+
Args:
1925+
logs:
1926+
The logs from Docker image pull.
1927+
image:
1928+
The image being pulled.
1929+
tag:
1930+
The image tag being pulled.
1931+
1932+
"""
1933+
if (
1934+
not envs.GPUSTACK_RUNTIME_DOCKER_IMAGE_NO_PULL_VISUALIZATION
1935+
and sys.stderr.isatty()
1936+
):
1937+
_visualize_pull_logs(logs, tag)
1938+
else:
1939+
_textualize_pull_logs(logs, image, tag)
1940+
1941+
1942+
def _visualize_pull_logs(logs, tag):
1943+
"""
1944+
Display Docker image pull logs as progress bars.
1945+
1946+
Args:
1947+
logs:
1948+
The logs from Docker image pull.
1949+
tag:
1950+
The image tag being pulled.
1951+
1952+
"""
1953+
pbars: dict[str, tqdm] = {}
1954+
dmsgs: list[str] = []
1955+
1956+
try:
1957+
for log in logs:
1958+
id_ = log.get("id", None)
1959+
status = log.get("status", "")
1960+
if not id_:
1961+
dmsgs.append(status)
1962+
continue
1963+
if id_ == tag:
1964+
continue
1965+
1966+
progress = log.get("progressDetail", {})
1967+
progress_total = progress.get("total", None)
1968+
progress_current = progress.get("current", None)
1969+
1970+
if id_ not in pbars:
1971+
pbars[id_] = tqdm(
1972+
unit="B",
1973+
unit_scale=True,
1974+
desc=f"{id_}: {status}",
1975+
bar_format="{desc}",
1976+
)
1977+
continue
1978+
1979+
pbars[id_].desc = f"{id_}: {status}"
1980+
if progress_total is not None:
1981+
pbars[id_].total = progress_total
1982+
bf = "{desc} |{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]"
1983+
pbars[id_].bar_format = bf
1984+
elif progress_current is not None:
1985+
pbars[id_].bar_format = "{desc} {n_fmt} [{rate_fmt}{postfix}]"
1986+
else:
1987+
pbars[id_].bar_format = "{desc}"
1988+
1989+
if progress_current:
1990+
pbars[id_].n = progress_current
1991+
1992+
pbars[id_].refresh()
1993+
finally:
1994+
for pbar in pbars.values():
1995+
pbar.close()
1996+
pbars.clear()
1997+
1998+
for msg in dmsgs:
1999+
print(msg, flush=True)
2000+
2001+
2002+
def _textualize_pull_logs(logs, image, tag):
2003+
"""
2004+
Display Docker image pull logs as plain text.
2005+
2006+
Args:
2007+
logs:
2008+
The logs from Docker image pull.
2009+
image:
2010+
The image being pulled.
2011+
tag:
2012+
The image tag being pulled.
2013+
2014+
"""
2015+
pstats: dict[str, tuple[int, int]] = {}
2016+
pstats_cursor: int = 0
2017+
pstats_cursor_move: int = 1
2018+
dmsgs: list[str] = []
2019+
2020+
for log in logs:
2021+
id_ = log.get("id", None)
2022+
status = log.get("status", "")
2023+
if not id_:
2024+
dmsgs.append(status)
2025+
continue
2026+
if id_ == tag:
2027+
continue
2028+
2029+
if id_ not in pstats:
2030+
pstats[id_] = (0, 0)
2031+
continue
2032+
2033+
progress = log.get("progressDetail", {})
2034+
progress_total = progress.get("total", None)
2035+
progress_current = progress.get("current", None)
2036+
2037+
if progress_total is not None or progress_current is not None:
2038+
pstats[id_] = (progress_total or 0, progress_current or 0)
2039+
2040+
pstats_total, pstats_current = 0, 0
2041+
for t, c in pstats.values():
2042+
pstats_total += t
2043+
pstats_current += c
2044+
2045+
if pstats_total:
2046+
pstats_cursor_diff = int(
2047+
pstats_current * 100 // pstats_total - pstats_cursor,
2048+
)
2049+
if pstats_cursor_diff >= pstats_cursor_move and pstats_cursor < 100:
2050+
pstats_cursor += pstats_cursor_diff
2051+
pstats_cursor_move = min(5, pstats_cursor_move + 1)
2052+
print(f"Pulling image {image}: {pstats_cursor}%", flush=True)
2053+
elif pstats_current:
2054+
pstats_cursor_diff = int(
2055+
pstats_current - pstats_cursor,
2056+
)
2057+
if pstats_cursor_diff >= pstats_cursor_move:
2058+
pstats_cursor += pstats_cursor_diff
2059+
pstats_cursor_move = min(200 * _MiB, pstats_cursor_move + 2 * _MiB)
2060+
pstats_cursor_human = bytes_to_human_readable(pstats_cursor)
2061+
print(f"Pulling image {image}: {pstats_cursor_human}", flush=True)
2062+
2063+
for msg in dmsgs:
2064+
print(msg, flush=True)

gpustack_runtime/envs.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@
5353
The detected backend mapping to resource keys,
5454
e.g `{"cuda": "nvidia.com/devices", "rocm": "amd.com/devices"}`.
5555
"""
56+
GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY: bool = True
57+
"""
58+
Use physical index priority at detecting devices.
59+
"""
5660
## Deployer
5761
GPUSTACK_RUNTIME_DEPLOY: str | None = None
5862
"""
@@ -174,13 +178,6 @@
174178
alignment is performed to ensure they are correctly identified.
175179
"""
176180

177-
# Detector
178-
179-
GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY: bool = True
180-
"""
181-
Use physical index priority at detecting devices.
182-
"""
183-
184181
# Deployer
185182

186183
## Docker
@@ -190,6 +187,10 @@
190187
Only works when `GPUSTACK_RUNTIME_DEPLOY_MIRRORED_NAME` is not set.
191188
Normally, it should be injected automatically via CI without any manual configuration.
192189
"""
190+
GPUSTACK_RUNTIME_DOCKER_IMAGE_NO_PULL_VISUALIZATION: bool = False
191+
"""
192+
Disable image pull visualization in Docker deployer.
193+
"""
193194
GPUSTACK_RUNTIME_DOCKER_PAUSE_IMAGE: str | None = None
194195
"""
195196
Docker image used for the pause container.
@@ -246,6 +247,7 @@
246247
"GPUSTACK_RUNTIME_LOG_EXCEPTION": lambda: to_bool(
247248
getenv("GPUSTACK_RUNTIME_LOG_EXCEPTION", "1"),
248249
),
250+
## Detector
249251
"GPUSTACK_RUNTIME_DETECT": lambda: getenv(
250252
"GPUSTACK_RUNTIME_DETECT",
251253
"Auto",
@@ -271,6 +273,10 @@
271273
"cuda=nvidia.com/devices;",
272274
),
273275
),
276+
"GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY": lambda: to_bool(
277+
getenv("GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY", "1"),
278+
),
279+
## Deployer
274280
"GPUSTACK_RUNTIME_DEPLOY": lambda: getenv(
275281
"GPUSTACK_RUNTIME_DEPLOY",
276282
"Auto",
@@ -373,11 +379,8 @@
373379
),
374380
sep=",",
375381
),
376-
# Detector
377-
"GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY": lambda: to_bool(
378-
getenv("GPUSTACK_RUNTIME_DETECT_PHYSICAL_INDEX_PRIORITY", "1"),
379-
),
380382
# Deployer
383+
## Docker
381384
"GPUSTACK_RUNTIME_DOCKER_MIRRORED_NAME_FILTER_LABELS": lambda: to_dict(
382385
getenv(
383386
"GPUSTACK_RUNTIME_DOCKER_MIRRORED_NAME_FILTER_LABELS",
@@ -401,6 +404,10 @@
401404
"GPUSTACK_RUNTIME_DOCKER_MUTE_ORIGINAL_HEALTHCHECK": lambda: to_bool(
402405
getenv("GPUSTACK_RUNTIME_DOCKER_MUTE_ORIGINAL_HEALTHCHECK", "1"),
403406
),
407+
"GPUSTACK_RUNTIME_DOCKER_IMAGE_NO_PULL_VISUALIZATION": lambda: to_bool(
408+
getenv("GPUSTACK_RUNTIME_DOCKER_IMAGE_NO_PULL_VISUALIZATION", "0"),
409+
),
410+
## Kubernetes
404411
"GPUSTACK_RUNTIME_KUBERNETES_NODE_NAME": lambda: getenv(
405412
"GPUSTACK_RUNTIME_KUBERNETES_NODE_NAME",
406413
None,

0 commit comments

Comments
 (0)