Skip to content
This repository was archived by the owner on Dec 27, 2022. It is now read-only.

Commit 585e650

Browse files
authored
Merge pull request #42 from richford/fix-pbar-and-int-subjects
Fix subjects=int bug and make the download progress bars better
2 parents 4e3cd12 + d062836 commit 585e650

File tree

1 file changed

+60
-50
lines changed

1 file changed

+60
-50
lines changed

dmriprep/data.py

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -252,20 +252,18 @@ def list_all_subjects(self):
252252
"""
253253
raise NotImplementedError
254254

255-
def postprocess(self, subject, pbar):
255+
def postprocess(self, subject):
256256
"""Study-specific postprocessing steps
257257
258258
Parameters
259259
----------
260260
subject : dmriprep.data.Subject
261261
subject instance
262-
263-
pbar : bool, default=True
264-
If True, include progress bar
265262
"""
266263
raise NotImplementedError
267264

268-
def download(self, directory, include_site=False, overwrite=False):
265+
def download(self, directory, include_site=False, overwrite=False,
266+
pbar=True):
269267
"""Download files for each subject in the study
270268
271269
Parameters
@@ -279,6 +277,9 @@ def download(self, directory, include_site=False, overwrite=False):
279277
overwrite : bool, default=False
280278
If True, overwrite files for each subject
281279
280+
pbar : bool, default=True
281+
If True, include progress bar
282+
282283
See Also
283284
--------
284285
dmriprep.data.Subject.download()
@@ -287,11 +288,11 @@ def download(self, directory, include_site=False, overwrite=False):
287288
directory=directory,
288289
include_site=include_site,
289290
overwrite=overwrite,
290-
pbar=False
291-
) for sub in self.subjects]
291+
pbar=pbar,
292+
pbar_idx=idx,
293+
) for idx, sub in enumerate(self.subjects)]
292294

293-
with ProgressBar():
294-
compute(*results, scheduler="threads")
295+
compute(*results, scheduler="threads")
295296

296297

297298
class HBN(Study):
@@ -362,7 +363,7 @@ def get_subs_from_tsv_key(s3_key):
362363

363364
return all_subjects
364365

365-
def postprocess(self, subject, pbar):
366+
def postprocess(self, subject):
366367
"""Move the T1 file back into the freesurfer directory.
367368
368369
This step is specific to the HBN dataset where the T1 files
@@ -373,17 +374,8 @@ def postprocess(self, subject, pbar):
373374
----------
374375
subject : dmriprep.data.Subject
375376
subject instance
376-
377-
pbar : bool, default=True
378-
If True, include progress bar
379377
"""
380-
if pbar:
381-
sessions_pbar = tqdm(subject.files.keys(),
382-
desc="Postprocess mriconvert T1W")
383-
else:
384-
sessions_pbar = subject.files.keys()
385-
386-
for sess in sessions_pbar:
378+
for sess in subject.files.keys():
387379
t1_file = subject.files[sess]['t1w'][0]
388380
freesurfer_path = op.join(op.dirname(t1_file), 'freesurfer')
389381

@@ -424,8 +416,9 @@ def __init__(self, subject_id, study, site=None):
424416
self._site = site
425417
self._valid = False
426418
self._organize_s3_keys()
427-
self._s3_keys = self._determine_directions(self._s3_keys)
428-
self._files = None
419+
if self.valid:
420+
self._s3_keys = self._determine_directions(self._s3_keys)
421+
self._files = None
429422

430423
@property
431424
def subject_id(self):
@@ -557,7 +550,7 @@ def _organize_s3_keys(self):
557550
self._s3_keys = None
558551

559552
def download(self, directory, include_site=False,
560-
overwrite=False, pbar=True):
553+
overwrite=False, pbar=True, pbar_idx=0):
561554
"""Download files from S3
562555
563556
Parameters
@@ -573,7 +566,17 @@ def download(self, directory, include_site=False,
573566
574567
pbar : bool, default=True
575568
If True, include download progress bar
569+
570+
pbar_idx : int, default=0
571+
Progress bar index for multithreaded progress bars
576572
"""
573+
if not self.valid:
574+
mod_logger.warning(
575+
f"Subject {self.subject_id} is not a valid subject. "
576+
f"Skipping download."
577+
)
578+
return
579+
577580
if include_site:
578581
directory = op.join(directory, self.site)
579582

@@ -583,36 +586,16 @@ def download(self, directory, include_site=False,
583586
)) for p in v] for k, v in self.s3_keys.items()
584587
}
585588

586-
if pbar:
587-
pbar_ftypes = tqdm(self.s3_keys.keys(),
588-
desc=f"Downloading {self.subject_id}")
589-
else:
590-
pbar_ftypes = self.s3_keys.keys()
589+
# Generate list of (key, file) tuples
590+
key_file_pairs = []
591591

592-
for ftype in pbar_ftypes:
593-
if pbar:
594-
pbar_ftypes.set_description(
595-
f"Downloading {self.subject_id} ({ftype})"
596-
)
592+
for ftype in self.s3_keys.keys():
597593
s3_keys = self.s3_keys[ftype]
598594
if isinstance(s3_keys, str):
599-
_download_from_s3(fname=files[ftype],
600-
bucket=self.study.bucket,
601-
key=s3_keys,
602-
overwrite=overwrite)
595+
key_file_pairs.append((s3_keys, files[ftype]))
603596
elif all(isinstance(x, str) for x in s3_keys):
604-
if pbar:
605-
file_zip = tqdm(zip(s3_keys, files[ftype]),
606-
desc=f"{ftype}",
607-
total=len(s3_keys),
608-
leave=False)
609-
else:
610-
file_zip = zip(s3_keys, files[ftype])
611-
for key, fname in file_zip:
612-
_download_from_s3(fname=fname,
613-
bucket=self.study.bucket,
614-
key=key,
615-
overwrite=overwrite)
597+
for key, fname in zip(s3_keys, files[ftype]):
598+
key_file_pairs.append((key, fname))
616599
else:
617600
raise TypeError(
618601
f"This subject {self.subject_id} has {ftype} S3 keys that "
@@ -625,8 +608,35 @@ def download(self, directory, include_site=False,
625608
if not files_by_session.keys():
626609
# There were no valid sessions
627610
self._valid = False
611+
mod_logger.warning(
612+
f"Subject {self.subject_id} is not a valid subject. "
613+
f"Skipping download."
614+
)
615+
return
628616

629-
self.study.postprocess(subject=self, pbar=pbar)
617+
# Now iterate through the list and download each item
618+
if pbar:
619+
progress = tqdm(desc=f"Download {self.subject_id}",
620+
position=pbar_idx,
621+
total=len(key_file_pairs) + 1)
622+
623+
for (key, fname) in key_file_pairs:
624+
_download_from_s3(fname=fname,
625+
bucket=self.study.bucket,
626+
key=key,
627+
overwrite=overwrite)
628+
629+
if pbar:
630+
progress.update()
631+
632+
if pbar:
633+
progress.set_description(f"Postproc {self.subject_id}")
634+
635+
self.study.postprocess(subject=self)
636+
637+
if pbar:
638+
progress.update()
639+
progress.close()
630640

631641
def _determine_directions(self,
632642
input_files,

0 commit comments

Comments
 (0)