|
20 | 20 | import docker.types |
21 | 21 | from dataclasses_json import dataclass_json |
22 | 22 | from docker.utils import parse_repository_tag |
| 23 | +from tqdm import tqdm |
23 | 24 |
|
24 | 25 | from .. import envs |
25 | 26 | from ..logging import debug_log_exception |
|
42 | 43 | WorkloadStatusOperation, |
43 | 44 | WorkloadStatusStateEnum, |
44 | 45 | ) |
45 | | -from .__utils__ import _MiB, bytes_to_human_readable, safe_json |
| 46 | +from .__utils__ import safe_json |
46 | 47 |
|
47 | 48 | if TYPE_CHECKING: |
48 | 49 | from collections.abc import Callable, Generator |
@@ -466,52 +467,53 @@ def _pull_image(self, image: str) -> docker.models.images.Image: |
466 | 467 | auth_config=auth_config, |
467 | 468 | ) |
468 | 469 |
|
469 | | - progress_threshold = 1 |
470 | | - progress_in_percent = None |
471 | | - progress: int = 0 |
472 | | - progress_current: int = 0 |
473 | | - layers: dict[str, int] = {} |
474 | | - layer_progress: dict[str, int] = {} |
475 | | - layer_progress_current: dict[str, int] = {} |
| 470 | + layers: dict[str, tqdm] = {} |
| 471 | + |
| 472 | + def clean_layers(): |
| 473 | + if not layers: |
| 474 | + return |
| 475 | + for layer in layers.values(): |
| 476 | + layer.close() |
| 477 | + layers.clear() |
| 478 | + |
476 | 479 | for log in logs: |
477 | 480 | if "id" not in log: |
| 481 | + clean_layers() |
478 | 482 | logger.info(log["status"]) |
479 | 483 | continue |
480 | | - log_id = log["id"] |
481 | | - if log_id not in layers: |
482 | | - layers[log_id] = len(layers) |
483 | | - layer_progress[log_id] = 0 |
484 | | - layer_progress_current[log_id] = 0 |
485 | | - p_c = log.get("progressDetail", {}).get("current") |
486 | | - p_t = log.get("progressDetail", {}).get("total") |
487 | | - if ( |
488 | | - (progress_in_percent is None or progress_in_percent) |
489 | | - and p_c is not None |
490 | | - and p_t is not None |
491 | | - ): |
492 | | - progress_in_percent = True |
493 | | - layer_progress[log_id] = int(p_c * 100 // p_t) |
494 | | - p_diff = ( |
495 | | - sum(layer_progress.values()) |
496 | | - * 100 |
497 | | - // (len(layer_progress) * 100) |
498 | | - - progress |
| 484 | + |
| 485 | + layer_id = log.get("id") |
| 486 | + layer_status = log.get("status", "") |
| 487 | + layer_progress = log.get("progressDetail", {}) |
| 488 | + layer_progress_total = layer_progress.get("total", None) |
| 489 | + layer_progress_current = layer_progress.get("current", None) |
| 490 | + if layer_id not in layers: |
| 491 | + layers[layer_id] = tqdm( |
| 492 | + unit="B", |
| 493 | + unit_scale=True, |
| 494 | + position=len(layers), |
| 495 | + ncols=70, |
| 496 | + desc=f"{layer_id}: {layer_status}", |
| 497 | + bar_format="{desc}", |
499 | 498 | ) |
500 | | - if p_diff >= progress_threshold: |
501 | | - progress += progress_threshold |
502 | | - progress_threshold = min(5, progress_threshold + 1) |
503 | | - logger.info(f"Pulling image {image}: {progress}%") |
504 | | - elif not progress_in_percent and p_c is not None: |
505 | | - progress_in_percent = False |
506 | | - layer_progress_current[log_id] = p_c |
507 | | - p_c_total = sum(layer_progress_current.values()) |
508 | | - p_diff = p_c_total - progress_current |
509 | | - if p_diff >= progress_threshold * _MiB: |
510 | | - progress_current = p_c_total |
511 | | - progress_threshold = min(200, progress_threshold + 2) |
512 | | - logger.info( |
513 | | - f"Pulling image {image}: {bytes_to_human_readable(p_c_total)}", |
514 | | - ) |
| 499 | + else: |
| 500 | + layers[layer_id].desc = f"{layer_id}: {layer_status}" |
| 501 | + |
| 502 | + if layer_progress_total is not None: |
| 503 | + layers[layer_id].total = layer_progress_total |
| 504 | + bf = "{desc} |{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]" |
| 505 | + layers[layer_id].bar_format = bf |
| 506 | + elif layer_progress_current is not None: |
| 507 | + layers[layer_id].bar_format = "{desc} {n_fmt} [{rate_fmt}{postfix}]" |
| 508 | + else: |
| 509 | + layers[layer_id].bar_format = "{desc}" |
| 510 | + |
| 511 | + if layer_progress_current: |
| 512 | + layers[layer_id].n = layer_progress_current |
| 513 | + |
| 514 | + layers[layer_id].refresh() |
| 515 | + |
| 516 | + clean_layers() |
515 | 517 |
|
516 | 518 | sep = "@" if tag.startswith("sha256:") else ":" |
517 | 519 | return self._client.images.get(f"{repo}{sep}{tag}") |
|
0 commit comments