@@ -252,20 +252,18 @@ def list_all_subjects(self):
252
252
"""
253
253
raise NotImplementedError
254
254
255
- def postprocess (self , subject , pbar ):
255
+ def postprocess (self , subject ):
256
256
"""Study-specific postprocessing steps
257
257
258
258
Parameters
259
259
----------
260
260
subject : dmriprep.data.Subject
261
261
subject instance
262
-
263
- pbar : bool, default=True
264
- If True, include progress bar
265
262
"""
266
263
raise NotImplementedError
267
264
268
- def download (self , directory , include_site = False , overwrite = False ):
265
+ def download (self , directory , include_site = False , overwrite = False ,
266
+ pbar = True ):
269
267
"""Download files for each subject in the study
270
268
271
269
Parameters
@@ -279,6 +277,9 @@ def download(self, directory, include_site=False, overwrite=False):
279
277
overwrite : bool, default=False
280
278
If True, overwrite files for each subject
281
279
280
+ pbar : bool, default=True
281
+ If True, include progress bar
282
+
282
283
See Also
283
284
--------
284
285
dmriprep.data.Subject.download()
@@ -287,11 +288,11 @@ def download(self, directory, include_site=False, overwrite=False):
287
288
directory = directory ,
288
289
include_site = include_site ,
289
290
overwrite = overwrite ,
290
- pbar = False
291
- ) for sub in self .subjects ]
291
+ pbar = pbar ,
292
+ pbar_idx = idx ,
293
+ ) for idx , sub in enumerate (self .subjects )]
292
294
293
- with ProgressBar ():
294
- compute (* results , scheduler = "threads" )
295
+ compute (* results , scheduler = "threads" )
295
296
296
297
297
298
class HBN (Study ):
@@ -362,7 +363,7 @@ def get_subs_from_tsv_key(s3_key):
362
363
363
364
return all_subjects
364
365
365
- def postprocess (self , subject , pbar ):
366
+ def postprocess (self , subject ):
366
367
"""Move the T1 file back into the freesurfer directory.
367
368
368
369
This step is specific to the HBN dataset where the T1 files
@@ -373,17 +374,8 @@ def postprocess(self, subject, pbar):
373
374
----------
374
375
subject : dmriprep.data.Subject
375
376
subject instance
376
-
377
- pbar : bool, default=True
378
- If True, include progress bar
379
377
"""
380
- if pbar :
381
- sessions_pbar = tqdm (subject .files .keys (),
382
- desc = "Postprocess mriconvert T1W" )
383
- else :
384
- sessions_pbar = subject .files .keys ()
385
-
386
- for sess in sessions_pbar :
378
+ for sess in subject .files .keys ():
387
379
t1_file = subject .files [sess ]['t1w' ][0 ]
388
380
freesurfer_path = op .join (op .dirname (t1_file ), 'freesurfer' )
389
381
@@ -424,8 +416,9 @@ def __init__(self, subject_id, study, site=None):
424
416
self ._site = site
425
417
self ._valid = False
426
418
self ._organize_s3_keys ()
427
- self ._s3_keys = self ._determine_directions (self ._s3_keys )
428
- self ._files = None
419
+ if self .valid :
420
+ self ._s3_keys = self ._determine_directions (self ._s3_keys )
421
+ self ._files = None
429
422
430
423
@property
431
424
def subject_id (self ):
@@ -557,7 +550,7 @@ def _organize_s3_keys(self):
557
550
self ._s3_keys = None
558
551
559
552
def download (self , directory , include_site = False ,
560
- overwrite = False , pbar = True ):
553
+ overwrite = False , pbar = True , pbar_idx = 0 ):
561
554
"""Download files from S3
562
555
563
556
Parameters
@@ -573,7 +566,17 @@ def download(self, directory, include_site=False,
573
566
574
567
pbar : bool, default=True
575
568
If True, include download progress bar
569
+
570
+ pbar_idx : int, default=0
571
+ Progress bar index for multithreaded progress bars
576
572
"""
573
+ if not self .valid :
574
+ mod_logger .warning (
575
+ f"Subject { self .subject_id } is not a valid subject. "
576
+ f"Skipping download."
577
+ )
578
+ return
579
+
577
580
if include_site :
578
581
directory = op .join (directory , self .site )
579
582
@@ -583,36 +586,16 @@ def download(self, directory, include_site=False,
583
586
)) for p in v ] for k , v in self .s3_keys .items ()
584
587
}
585
588
586
- if pbar :
587
- pbar_ftypes = tqdm (self .s3_keys .keys (),
588
- desc = f"Downloading { self .subject_id } " )
589
- else :
590
- pbar_ftypes = self .s3_keys .keys ()
589
+ # Generate list of (key, file) tuples
590
+ key_file_pairs = []
591
591
592
- for ftype in pbar_ftypes :
593
- if pbar :
594
- pbar_ftypes .set_description (
595
- f"Downloading { self .subject_id } ({ ftype } )"
596
- )
592
+ for ftype in self .s3_keys .keys ():
597
593
s3_keys = self .s3_keys [ftype ]
598
594
if isinstance (s3_keys , str ):
599
- _download_from_s3 (fname = files [ftype ],
600
- bucket = self .study .bucket ,
601
- key = s3_keys ,
602
- overwrite = overwrite )
595
+ key_file_pairs .append ((s3_keys , files [ftype ]))
603
596
elif all (isinstance (x , str ) for x in s3_keys ):
604
- if pbar :
605
- file_zip = tqdm (zip (s3_keys , files [ftype ]),
606
- desc = f"{ ftype } " ,
607
- total = len (s3_keys ),
608
- leave = False )
609
- else :
610
- file_zip = zip (s3_keys , files [ftype ])
611
- for key , fname in file_zip :
612
- _download_from_s3 (fname = fname ,
613
- bucket = self .study .bucket ,
614
- key = key ,
615
- overwrite = overwrite )
597
+ for key , fname in zip (s3_keys , files [ftype ]):
598
+ key_file_pairs .append ((key , fname ))
616
599
else :
617
600
raise TypeError (
618
601
f"This subject { self .subject_id } has { ftype } S3 keys that "
@@ -625,8 +608,35 @@ def download(self, directory, include_site=False,
625
608
if not files_by_session .keys ():
626
609
# There were no valid sessions
627
610
self ._valid = False
611
+ mod_logger .warning (
612
+ f"Subject { self .subject_id } is not a valid subject. "
613
+ f"Skipping download."
614
+ )
615
+ return
628
616
629
- self .study .postprocess (subject = self , pbar = pbar )
617
+ # Now iterate through the list and download each item
618
+ if pbar :
619
+ progress = tqdm (desc = f"Download { self .subject_id } " ,
620
+ position = pbar_idx ,
621
+ total = len (key_file_pairs ) + 1 )
622
+
623
+ for (key , fname ) in key_file_pairs :
624
+ _download_from_s3 (fname = fname ,
625
+ bucket = self .study .bucket ,
626
+ key = key ,
627
+ overwrite = overwrite )
628
+
629
+ if pbar :
630
+ progress .update ()
631
+
632
+ if pbar :
633
+ progress .set_description (f"Postproc { self .subject_id } " )
634
+
635
+ self .study .postprocess (subject = self )
636
+
637
+ if pbar :
638
+ progress .update ()
639
+ progress .close ()
630
640
631
641
def _determine_directions (self ,
632
642
input_files ,
0 commit comments