|
1 | 1 | import os
|
2 | 2 | from dataclasses import dataclass
|
3 | 3 | from pathlib import Path
|
4 |
| -from typing import Any, Dict, Iterable, List |
| 4 | +from typing import Any, Iterable, List |
5 | 5 |
|
6 | 6 | import click
|
7 | 7 | import structlog
|
|
10 | 10 | from IPython.core.magic_arguments import argument, magic_arguments
|
11 | 11 | from IPython.utils.process import arg_split
|
12 | 12 | from rich import print as rprint
|
13 |
| -from rich.progress import Progress, TaskID |
| 13 | +from tqdm.auto import tqdm |
14 | 14 | from traitlets import Float, Unicode
|
15 | 15 | from traitlets.config import Configurable
|
16 | 16 |
|
@@ -142,29 +142,27 @@ def process_file_update_stream(path: str, stream: DatasetOperationStream):
|
142 | 142 | error_message = None
|
143 | 143 | complete_message = None
|
144 | 144 |
|
145 |
| - with Progress() as progress: |
146 |
| - tasks_by_file_path: Dict[str, TaskID] = {} |
| 145 | + progress_bars = {} |
147 | 146 |
|
| 147 | + try: |
148 | 148 | for msg in stream:
|
149 | 149 | if isinstance(msg, StreamErrorMessage):
|
150 | 150 | error_message = msg.content.detail
|
151 | 151 | break
|
152 | 152 | elif isinstance(msg, FileProgressUpdateMessage):
|
153 | 153 | got_file_update_msg = True
|
154 | 154 |
|
155 |
| - if msg.content.file_name not in tasks_by_file_path: |
156 |
| - tasks_by_file_path[msg.content.file_name] = progress.add_task( |
157 |
| - msg.content.file_name, total=100.0 |
158 |
| - ) |
| 155 | + if msg.content.file_name not in progress_bars: |
| 156 | + progress_bars[msg.content.file_name] = tqdm(total=100.0, desc=msg.content.file_name) |
159 | 157 |
|
160 |
| - progress.update( |
161 |
| - tasks_by_file_path[msg.content.file_name], |
162 |
| - completed=msg.content.percent_complete * 100.0, |
163 |
| - ) |
| 158 | + progress_bars[msg.content.file_name].update(msg.content.percent_complete * 100.0) |
164 | 159 | elif isinstance(msg, FileProgressStartMessage):
|
165 |
| - progress.console.print(msg.content.message) |
| 160 | + print(msg.content.message) |
166 | 161 | elif isinstance(msg, FileProgressEndMessage) and got_file_update_msg:
|
167 | 162 | complete_message = msg.content.message
|
| 163 | + finally: |
| 164 | + for bar in progress_bars.values(): |
| 165 | + bar.close() |
168 | 166 |
|
169 | 167 | if error_message:
|
170 | 168 | rprint(f"[red]{error_message}[/red]")
|
|
0 commit comments