Skip to content

Commit 039b2a4

Browse files
committed
Reduce looping & memory usage in ProgressCombiner
1 parent 51d8348 commit 039b2a4

File tree

1 file changed

+63
-46
lines changed

1 file changed

+63
-46
lines changed

dandi/download.py

Lines changed: 63 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1154,12 +1154,11 @@ def pairing(p: str, gen: Iterator[dict]) -> Iterator[tuple[str, dict]]:
11541154
yield (p, d)
11551155

11561156

1157-
DLState = Enum("DLState", "STARTING DOWNLOADING SKIPPED ERROR CHECKSUM_ERROR DONE")
1157+
DLState = Enum("DLState", "STARTING DOWNLOADING SKIPPED ERROR DONE")
11581158

11591159

11601160
@dataclass
11611161
class DownloadProgress:
1162-
state: DLState = DLState.STARTING
11631162
downloaded: int = 0
11641163
size: int | None = None
11651164

@@ -1170,46 +1169,48 @@ class ProgressCombiner:
11701169
file_qty: int | None = (
11711170
None # set to specific known value whenever full sweep is complete
11721171
)
1173-
files: dict[str, DownloadProgress] = field(default_factory=dict)
1172+
downloading: dict[str, DownloadProgress] = field(default_factory=dict)
1173+
#: The total number of bytes downloaded so far, including all files
1174+
#: currently downloading, skipped, or finished downloading (even if
1175+
#: the checksum check failed)
1176+
total_downloaded: int = 0
11741177
#: Total size of all files that were not skipped and did not error out
11751178
#: during download
11761179
maxsize: int = 0
11771180
prev_status: str = ""
11781181
yielded_size: bool = False
1182+
file_states: Counter[DLState] = field(default_factory=Counter)
1183+
files_fed: int = 0
11791184

11801185
def get_done(self) -> dict:
1181-
total_downloaded = sum(
1182-
s.downloaded
1183-
for s in self.files.values()
1184-
if s.state
1185-
in (
1186-
DLState.DOWNLOADING,
1187-
DLState.CHECKSUM_ERROR,
1188-
DLState.SKIPPED,
1189-
DLState.DONE,
1190-
)
1191-
)
11921186
return {
1193-
"done": total_downloaded,
1194-
"done%": total_downloaded / self.zarr_size * 100 if self.zarr_size else 0,
1187+
"done": self.total_downloaded,
1188+
"done%": (
1189+
self.total_downloaded / self.zarr_size * 100 if self.zarr_size else 0
1190+
),
11951191
}
11961192

11971193
def get_status(self, report_done: bool = True) -> dict:
1198-
state_qtys = Counter(s.state for s in self.files.values())
1199-
total = len(self.files)
12001194
if (
12011195
self.file_qty is not None # if already known
1202-
and total == self.file_qty
1203-
and state_qtys[DLState.STARTING] == state_qtys[DLState.DOWNLOADING] == 0
1196+
and self.files_fed == self.file_qty
1197+
and self.file_states[DLState.STARTING]
1198+
== self.file_states[DLState.DOWNLOADING]
1199+
== 0
12041200
):
12051201
# All files have finished
1206-
if state_qtys[DLState.ERROR] or state_qtys[DLState.CHECKSUM_ERROR]:
1202+
if self.file_states[DLState.ERROR]:
12071203
new_status = "error"
1208-
elif state_qtys[DLState.DONE]:
1204+
elif self.file_states[DLState.DONE]:
12091205
new_status = "done"
12101206
else:
12111207
new_status = "skipped"
1212-
elif total - state_qtys[DLState.STARTING] - state_qtys[DLState.SKIPPED] > 0:
1208+
elif (
1209+
self.files_fed
1210+
- self.file_states[DLState.STARTING]
1211+
- self.file_states[DLState.SKIPPED]
1212+
> 0
1213+
):
12131214
new_status = "downloading"
12141215
else:
12151216
new_status = ""
@@ -1218,12 +1219,12 @@ def get_status(self, report_done: bool = True) -> dict:
12181219

12191220
if report_done:
12201221
msg_comps = []
1221-
for msg_label, states in {
1222-
"done": (DLState.DONE,),
1223-
"errored": (DLState.ERROR, DLState.CHECKSUM_ERROR),
1224-
"skipped": (DLState.SKIPPED,),
1225-
}.items():
1226-
if count := sum(state_qtys.get(state, 0) for state in states):
1222+
for msg_label, state in [
1223+
("done", DLState.DONE),
1224+
("errored", DLState.ERROR),
1225+
("skipped", DLState.SKIPPED),
1226+
]:
1227+
if count := self.file_states[state]:
12271228
msg_comps.append(f"{count} {msg_label}")
12281229
if msg_comps:
12291230
statusdict["message"] = ", ".join(msg_comps)
@@ -1238,49 +1239,65 @@ def get_status(self, report_done: bool = True) -> dict:
12381239

12391240
def feed(self, path: str, status: dict) -> Iterator[dict]:
12401241
keys = list(status.keys())
1241-
self.files.setdefault(path, DownloadProgress())
12421242
size = status.get("size")
12431243
if size is not None:
12441244
if not self.yielded_size:
12451245
# this thread will yield
12461246
self.yielded_size = True
12471247
yield {"size": self.zarr_size}
12481248
if status.get("status") == "skipped":
1249-
self.files[path].state = DLState.SKIPPED
1249+
self.files_fed += 1
1250+
self.file_states[DLState.SKIPPED] += 1
1251+
try:
1252+
self.total_downloaded -= self.downloading.pop(path).downloaded
1253+
except KeyError:
1254+
pass
12501255
# Treat skipped as "downloaded" for the matter of accounting
12511256
if size is not None:
1252-
self.files[path].downloaded = size
1257+
self.total_downloaded += size
12531258
self.maxsize += size
12541259
yield self.get_status()
12551260
elif keys == ["size"]:
1256-
self.files[path].size = size
1257-
self.maxsize += status["size"]
1258-
if any(s.state is DLState.DOWNLOADING for s in self.files.values()):
1261+
self.files_fed += 1
1262+
self.file_states[DLState.STARTING] += 1
1263+
assert size is not None
1264+
self.downloading[path] = DownloadProgress(size=size)
1265+
self.maxsize += size
1266+
if self.file_states[DLState.DOWNLOADING]:
12591267
yield self.get_done()
12601268
elif status == {"status": "downloading"}:
1261-
self.files[path].state = DLState.DOWNLOADING
1269+
self.file_states[DLState.DOWNLOADING] += 1
1270+
if path not in self.downloading:
1271+
self.files_fed += 1
1272+
self.downloading[path] = DownloadProgress()
1273+
else:
1274+
self.file_states[DLState.STARTING] -= 1
12621275
if out := self.get_status(report_done=False):
12631276
yield out
12641277
elif "done" in status:
1265-
self.files[path].downloaded = status["done"]
1278+
prev_done = self.downloading[path].downloaded
1279+
self.total_downloaded += status["done"] - prev_done
1280+
self.downloading[path].downloaded = status["done"]
12661281
yield self.get_done()
12671282
elif status.get("status") == "error":
1268-
if "checksum" in status:
1269-
self.files[path].state = DLState.CHECKSUM_ERROR
1270-
else:
1271-
self.files[path].state = DLState.ERROR
1272-
sz = self.files[path].size
1273-
if sz is not None:
1274-
self.maxsize -= sz
1283+
self.file_states[DLState.DOWNLOADING] -= 1
1284+
self.file_states[DLState.ERROR] += 1
1285+
progress = self.downloading.pop(path)
1286+
if "checksum" not in status:
1287+
if progress.size is not None:
1288+
self.maxsize -= progress.size
1289+
self.total_downloaded -= progress.downloaded
12751290
yield self.get_status()
12761291
elif keys == ["checksum"]:
12771292
pass
12781293
elif status == {"status": "setting mtime"}:
12791294
pass
12801295
elif status == {"status": "done"}:
1281-
self.files[path].state = DLState.DONE
1296+
del self.downloading[path]
1297+
self.file_states[DLState.DOWNLOADING] -= 1
1298+
self.file_states[DLState.DONE] += 1
12821299
yield self.get_status()
12831300
else:
12841301
lgr.warning(
1285-
"Unexpected download status dict for %r received: %r", path, status
1302+
"Unexpected download status dict received for %r: %r", path, status
12861303
)

0 commit comments

Comments
 (0)