13
13
import pandas as pd
14
14
from dask import compute , delayed
15
15
from dask .diagnostics import ProgressBar
16
+ from tqdm .auto import tqdm
16
17
17
18
18
19
mod_logger = logging .getLogger (__name__ )
@@ -203,8 +204,12 @@ def postprocess(self):
203
204
raise NotImplementedError
204
205
205
206
def download (self , directory , include_site = False , overwrite = False ):
206
- results = [delayed (sub .download )(directory , include_site , overwrite )
207
- for sub in self .subjects ]
207
+ results = [delayed (sub .download )(
208
+ directory = directory ,
209
+ include_site = include_site ,
210
+ overwrite = overwrite ,
211
+ pbar = False
212
+ ) for sub in self .subjects ]
208
213
209
214
with ProgressBar ():
210
215
compute (* results , scheduler = "threads" )
@@ -373,7 +378,8 @@ def _organize_s3_keys(self):
373
378
self ._valid = False
374
379
self ._s3_keys = None
375
380
376
- def download (self , directory , include_site = False , overwrite = False ):
381
+ def download (self , directory , include_site = False ,
382
+ overwrite = False , pbar = True ):
377
383
if include_site :
378
384
directory = op .join (directory , self .site )
379
385
@@ -383,14 +389,25 @@ def download(self, directory, include_site=False, overwrite=False):
383
389
)) for p in v ] for k , v in self .s3_keys .items ()
384
390
}
385
391
386
- for ftype , s3_keys in self .s3_keys .items ():
392
+ pbar_ftypes = tqdm (self .s3_keys .keys (),
393
+ desc = f"Downloading { self .subject_id } " )
394
+
395
+ for ftype in pbar_ftypes :
396
+ pbar_ftypes .set_description (
397
+ f"Downloading { self .subject_id } ({ ftype } )"
398
+ )
399
+ s3_keys = self .s3_keys [ftype ]
387
400
if isinstance (s3_keys , str ):
388
401
_download_from_s3 (fname = files [ftype ],
389
402
bucket = self .study .bucket ,
390
403
key = s3_keys ,
391
404
overwrite = overwrite )
392
405
elif all (isinstance (x , str ) for x in s3_keys ):
393
- for key , fname in zip (s3_keys , files [ftype ]):
406
+ file_zip = tqdm (zip (s3_keys , files [ftype ]),
407
+ desc = f"{ ftype } " ,
408
+ total = len (s3_keys ),
409
+ leave = False )
410
+ for key , fname in file_zip :
394
411
_download_from_s3 (fname = fname ,
395
412
bucket = self .study .bucket ,
396
413
key = key ,
@@ -420,7 +437,7 @@ def _determine_directions(self,
420
437
421
438
Parameters
422
439
----------
423
- input_files : InputFiles namedtuple
440
+ input_files : dict
424
441
The local input files for the subject
425
442
426
443
input_type : "s3" or "local", default="s3"
@@ -572,8 +589,8 @@ def _separate_sessions(self, input_files, multiples_policy='sessions',
572
589
573
590
Returns
574
591
-------
575
- list of InputFiles namedtuples
576
- List of InputFiles namedtuples for each session ID.
592
+ dict of dicts
593
+ Dict of Dicts of file names
577
594
"""
578
595
if multiples_policy not in ['sessions' , 'concatenate' ]:
579
596
raise ValueError ('`multiples_policy` must be either "sessions" or '
0 commit comments