-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathnemo_lesion_to_chaco.py
More file actions
1342 lines (1087 loc) · 59.5 KB
/
nemo_lesion_to_chaco.py
File metadata and controls
1342 lines (1087 loc) · 59.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import multiprocessing
import os
from pathlib import Path
import numpy as np
import nibabel as nib
import nibabel.processing
import time
import sys
from scipy import sparse
from matplotlib import pyplot as plt
from nilearn import plotting, image
from scipy import ndimage
import argparse
import tempfile
import subprocess
import boto3
import pickle
import shutil
from itertools import repeat
# Worker-count override read by get_available_cpus(): None (or any value <= 0)
# means "use every CPU available to this process".
NUM_THREADS = None
def argument_parse(arglist):
    """Parse the command-line arguments of this script.

    arglist: list of argument strings (typically ``sys.argv[1:]``).
    Returns the ``argparse.Namespace`` produced by ``parse_args``.
    """
    parser=argparse.ArgumentParser(description='Read lesion mask and create voxel-wise ChaCo maps for all reference subjects')
    parser.add_argument('--lesion','-l',action='store', dest='lesion')
    parser.add_argument('--outputbase','-o',action='store', dest='outputbase')
    parser.add_argument('--chunklist','-c',action='store', dest='chunklist')
    parser.add_argument('--chunkdir','-cd',action='store', dest='chunkdir')
    parser.add_argument('--refvol','-r',action='store', dest='refvol')
    parser.add_argument('--endpoints','-e',action='store', dest='endpoints')
    parser.add_argument('--endpointsmask','-em',action='store', dest='endpointsmask')
    parser.add_argument('--asum','-a',action='store', dest='asum')
    parser.add_argument('--asum_weighted','-aw',action='store', dest='asum_weighted')
    parser.add_argument('--asum_cumulative','-ac',action='store', dest='asum_cumulative')
    parser.add_argument('--asum_weighted_cumulative','-acw',action='store', dest='asum_weighted_cumulative')
    parser.add_argument('--trackweights','-t',action='store', dest='trackweights')
    parser.add_argument('--tracklengths','-tl',action='store',dest='tracklengths')
    parser.add_argument('--weighted','-w',action='store_true', dest='weighted')
    parser.add_argument('--smoothed','-s',action='store_true', dest='smoothed')
    # type=float so a user-supplied FWHM is numeric like the default (it is used in
    # arithmetic when smoothing); without it, "-sw 4" would arrive as the string "4"
    parser.add_argument('--smoothfwhm','-sw',default=6, type=float, action='store', dest='smoothfwhm', help = 'default: %(default)d')
    parser.add_argument('--smoothmode','-sm',default='ratio', action='store', dest='smoothmode', choices=['ratio','counts'], help = 'default: %(default)s')
    parser.add_argument('--s3nemoroot','-s3',action='store', dest='s3nemoroot')
    # may be given multiple times (one entry per parcellation)
    parser.add_argument('--parcelvol','-p',action='append', dest='parcelvol')
    parser.add_argument('--resolution','-res',action='append', dest='resolution')
    parser.add_argument('--cumulative',action='store_true', dest='cumulative')
    parser.add_argument('--pairwise',action='store_true', dest='pairwise')
    parser.add_argument('--continuous_value',action='store_true', dest='continuous_value')
    parser.add_argument('--tracking_algorithm',action='store',dest='tracking_algorithm')
    parser.add_argument('--debug',action='store_true', dest='debug')
    parser.add_argument('--subjcount',action='store', dest='subjcount', type=int, help='number of reference subjects to compute (for debugging only!)')
    parser.add_argument('--onlynonzerodenom',action='store_true',dest='only_nonzero_denom', help='only include subjects with non-zero denominator for a given voxel')
    parser.add_argument('--numthreads',action='store',dest='num_threads', type=int, help='How many threads to use for execution (default=all)')
    return parser.parse_args(arglist)
def get_available_cpus():
    """Return how many worker processes this script may use (always >= 1).

    Priority: the module-level ``NUM_THREADS`` override when it is a positive
    number; otherwise the CPUs in this process's affinity mask
    (``os.sched_getaffinity``). Where that call is unavailable (non-Linux), fall
    back to ``$SLURM_CPUS_PER_TASK`` and then ``os.cpu_count()``.
    """
    if NUM_THREADS is not None and NUM_THREADS > 0:
        return NUM_THREADS
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # fallback for platforms without sched_getaffinity (non-Linux)
        slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
        if slurm_cpus:
            return int(slurm_cpus)
        # os.cpu_count() may return None when the count cannot be determined
        return os.cpu_count() or 1
def durationToString(numseconds):
    """Render a duration given in seconds as a short human-readable string.

    Durations below one minute come out as e.g. "12.345 seconds". Longer ones
    are written as compact day/hour/minute/second fields such as "1h2m3.500s",
    with zero-valued fields left out.
    """
    if numseconds < 60:
        return "%.3f seconds" % (numseconds)

    fraction = numseconds % 1
    whole_seconds = int(numseconds) % 60
    minutes = int(numseconds / 60) % 60
    hours = int(numseconds / (60*60)) % 24
    days = int(numseconds / (60*60*24))

    pieces = [
        "%g%s" % (amount, unit)
        for amount, unit in ((days, "d"), (hours, "h"), (minutes, "m"))
        if amount > 0
    ]
    # seconds keep millisecond precision only when there is a fractional part
    if fraction > 0:
        pieces.append("%.3fs" % (whole_seconds + fraction))
    elif whole_seconds > 0:
        pieces.append("%gs" % (whole_seconds))
    return "".join(pieces)
def createSparseDownsampleParcellation(newvoxmm, origvoxmm, volshape, refimg):
    """Build a sparse voxel -> coarse-cell transform for block downsampling.

    Each original voxel (i, j, k) is assigned to coarse cell
    (floor(i/newvoxmm), floor(j/newvoxmm), floor(k/newvoxmm)).

    newvoxmm: target resolution; used both as a block size in voxels (for the
        index arithmetic) and, divided by origvoxmm, as the affine scale factor.
    origvoxmm: original resolution (same units as newvoxmm).
    volshape: shape of the original 3D volume.
    refimg: reference image whose affine is rescaled (nibabel image).

    Returns (Psparse, newvolshape, newrefimg):
        Psparse: float32 csr matrix (numvoxels x numchunks) with a 1 at
            (C-order voxel index, C-order coarse-cell index).
        newvolshape: int32 array, ceil(volshape / newvoxmm).
        newrefimg: refimg resampled (nearest neighbour) onto the coarse grid.
    """
    # coarse-grid coordinate of every original voxel along each axis
    chunkvec_x=np.int32(np.floor(np.arange(volshape[0])/newvoxmm))
    chunkvec_y=np.int32(np.floor(np.arange(volshape[1])/newvoxmm))
    chunkvec_z=np.int32(np.floor(np.arange(volshape[2])/newvoxmm))
    chunkvec_size=(chunkvec_x[-1]+1, chunkvec_y[-1]+1, chunkvec_z[-1]+1)
    chunkx,chunky,chunkz=np.meshgrid(chunkvec_x,chunkvec_y,chunkvec_z,indexing='ij')
    # C-order linear index of each voxel's coarse cell. The previous hand-written
    # formula used the x extent as the y stride, which is only correct when the
    # x and z extents of the coarse grid are equal.
    chunkidx_flat=np.ravel_multi_index((chunkx.ravel(),chunky.ravel(),chunkz.ravel()),chunkvec_size)
    numchunks=np.max(chunkidx_flat)+1
    newvolshape=np.ceil(np.array(volshape)/newvoxmm).astype(np.int32)
    numvoxels=np.prod(volshape)
    _, uidx=np.unique(chunkidx_flat, return_inverse=True)
    Psparse=sparse.csr_matrix((np.ones(numvoxels),(np.arange(numvoxels),uidx)),shape=(numvoxels,numchunks),dtype=np.float32)
    newaff=refimg.affine.copy()
    newaff[:3,:3]*=newvoxmm/origvoxmm
    #because voxel center is 0.5 in orig and 0.5*res in the new one, we need to add a small shift to the new reference volume so it properly overlays
    voxoffset=(newvoxmm-origvoxmm)/2.0
    newaff[:3,-1]+=np.sign(refimg.affine[:3,:3]) @ [voxoffset,voxoffset,voxoffset]
    newrefimg=nib.processing.resample_from_to(refimg,(newvolshape,newaff),order=0)
    return Psparse, newvolshape, newrefimg
def flatParcellationToTransform(Pflat, isubj=None, out_type="csr", max_sequential_roi_value=None):
    """Convert a flattened label volume into a sparse voxel x ROI indicator matrix.

    Pflat: a flattened label volume (1D array), or a 2D dense/scipy.sparse array
        holding one flattened label volume per row. Label 0 is background.
    isubj: row of a 2D ``Pflat`` to convert. If None, the whole array is flattened.
    out_type: "csr" or "csc" output format.
    max_sequential_roi_value: if given, label v maps to column v-1 and the output
        has exactly this many columns (labels are assumed to be integers in
        1..max). Otherwise the columns follow the sorted unique nonzero labels
        actually present.

    Returns a float32 sparse matrix of shape (numvoxels, numroi) with a 1 at
    (voxel, roi) for every labelled voxel.
    Raises ValueError for an unsupported ``out_type``.
    """
    if out_type not in ("csr", "csc"):
        raise ValueError("out_type must be 'csr' or 'csc', got %r" % (out_type,))

    if sparse.issparse(Pflat):
        if isubj is None:
            Pdata=Pflat.toarray().flatten()
        else:
            Pdata=Pflat[isubj,:].toarray().flatten()
    else:
        Pflat=np.asarray(Pflat)
        # a 1D input has no subject axis, so isubj is irrelevant there
        if isubj is None or Pflat.ndim < 2:
            Pdata=Pflat.flatten()
        else:
            Pdata=Pflat[isubj,:].flatten()

    numvoxels=Pdata.size
    labelled=Pdata!=0
    pmaskidx=np.flatnonzero(labelled)
    uroi, uidx=np.unique(Pdata[labelled],return_inverse=True)
    numroi=len(uroi)
    if max_sequential_roi_value is not None:
        #this would create an entry at the actual ROI values, rather than just going through the sequential PRESENT value
        #eg: for cc400 it would be a 7M x 400 array instead of 7M x 392
        # but for an arbitrary/custom input, where they left freesurfer values, this could make it in the thousands!
        uidx=(uroi[uidx]-1).astype(np.int64)
        # int() accepts Python ints as well as numpy scalars / 0-d arrays
        numroi=int(max_sequential_roi_value)

    matrix_type=sparse.csr_matrix if out_type == "csr" else sparse.csc_matrix
    return matrix_type((np.ones(pmaskidx.size),(pmaskidx,uidx)),shape=(numvoxels,numroi),dtype=np.float32)
def checkVolumeShape(Pimg, refimg, filename_display, expected_shape, expected_shape_spm):
    """Validate an input image's grid and resample it onto ``refimg``.

    Images with trailing singleton dimensions (e.g. X x Y x Z x 1) are first
    squeezed to 3D. Images matching ``expected_shape`` or ``expected_shape_spm``
    are resampled with nearest-neighbour interpolation onto ``refimg``; the SPM
    shape additionally prints a notice. Any other shape raises Exception.

    filename_display: name used in the printed notice / error message.
    Returns the resampled image.
    """
    dims=Pimg.shape
    has_singleton_tail=len(dims)>=4 and all(d==1 for d in dims[3:])
    if has_singleton_tail:
        # some nii files include a 4th dimension even for single volumes (eg: outputs from nifti_4dfp)
        voxels=Pimg.get_fdata().flatten()[:np.prod(dims)]
        Pimg=nib.Nifti1Image(np.reshape(voxels,dims[:3]),affine=Pimg.affine,header=Pimg.header)
        dims=Pimg.shape

    if dims == expected_shape:
        # resample even when the shape matches, in case of an LPI vs RPI orientation mismatch
        return nibabel.processing.resample_from_to(Pimg,refimg,order=0)

    if dims == expected_shape_spm:
        print('%s was 181x217x181, not the expected 182x218x182. Resampling to expected.' % (filename_display))
        return nibabel.processing.resample_from_to(Pimg,refimg,order=0)

    shapestr=",".join(str(d) for d in Pimg.shape)
    raise(Exception('Unexpected volume size: (%s) for %s. Each input must be a SINGLE volume registered to 182x218x182 MNIv6 template (FSL template)' % (shapestr,filename_display)))
def loadParcellation(filename, numroi=None, refimg=None, expected_shape=None, expected_shape_spm=None):
    """Load a parcellation as sparse voxel x ROI transform matrices.

    Supported inputs, chosen by file extension:
      * ``.npz``: one flattened label volume per row. ``*.masked.npz`` and
        ``*.dense.npz`` hold a dense ``data`` array, optionally with a ``mask``
        whose nonzero entries give the voxel columns of ``data``. Any other
        ``.npz`` is read with ``scipy.sparse.load_npz``. Returns a list with one
        csc matrix per row. Label v goes to column v-1, and the column count is
        ``numroi`` if given, else the largest label in the file.
      * ``.pkl``: a pickled list of voxel x ROI transform matrices, returned as-is.
      * anything else: a NIfTI volume. It is checked/resampled onto ``refimg`` with
        ``checkVolumeShape(..., expected_shape, expected_shape_spm)`` and a
        single csr voxel x ROI matrix is returned (label v -> column v-1 only
        when ``numroi`` is given).
    """
    filename_lower=filename.lower()
    if filename_lower.endswith(".npz"):
        if filename_lower.endswith((".masked.npz", ".dense.npz")):
            # NOTE(review): allow_pickle=True lets a crafted .npz run code on load;
            # only use trusted parcellation files here.
            with np.load(filename, allow_pickle=True) as Pdata:
                if 'mask' in Pdata:
                    # 'data' stores only the masked voxels; scatter back to full rows
                    Pmask = Pdata['mask']
                    Pmasked = Pdata['data']
                    Pfull=np.zeros((Pmasked.shape[0],Pmask.shape[1]),dtype=Pmasked.dtype)
                    Pfull[:,Pmask.ravel()>0]=Pmasked
                    Psparse_allsubj=sparse.csr_matrix(Pfull)
                    del Pfull
                else:
                    Psparse_allsubj=sparse.csr_matrix(Pdata['data'])
        else:
            Psparse_allsubj=sparse.load_npz(filename)
        numroisubj=Psparse_allsubj.shape[0]
        if numroi is not None:
            max_seq_roi_val=np.array(numroi)
        else:
            max_seq_roi_val=Psparse_allsubj.max()

        # Declared global so its __qualname__ is just "flat2sparse" (PEP 3155):
        # Pool.map can then pickle it by name, and forked workers read
        # Psparse_allsubj from inherited memory instead of a pickled copy.
        global flat2sparse
        def flat2sparse(isubj):
            #store these as csc for memory efficiency (have to convert each subject later)
            return flatParcellationToTransform(Psparse_allsubj, isubj, out_type="csc", max_sequential_roi_value=max_seq_roi_val)

        # leave one CPU free, but never request an empty pool (Pool(0) raises ValueError)
        num_workers=max(1, get_available_cpus()-1)
        try:
            with multiprocessing.Pool(num_workers) as pool:
                Psparse=pool.map(flat2sparse,range(numroisubj))
        finally:
            del flat2sparse
    elif filename_lower.endswith(".pkl"):
        #this file type assumes a list of voxel x ROI Psparse transform matrices
        # NOTE(review): unpickling executes code from the file; only load trusted files.
        with open(filename,"rb") as fh:
            Psparse=pickle.load(fh)
    else:
        Pimg=nib.load(filename)
        Pimg = checkVolumeShape(Pimg, refimg, os.path.basename(filename), expected_shape, expected_shape_spm)
        Pdata=Pimg.get_fdata()
        max_seq_roi_val=None if numroi is None else np.array(numroi)
        Psparse = flatParcellationToTransform(Pdata.flatten(), None, out_type="csr", max_sequential_roi_value=max_seq_roi_val)
    return Psparse
def smooth_sparse_vol(sparsevals, fwhm, volshape, voxmm):
    """Gaussian-smooth a flattened sparse volume; return the result as a sparse row.

    sparsevals: sparse matrix whose values reshape to ``volshape``.
    fwhm: kernel full width at half maximum, in the same units as ``voxmm``.
    voxmm: voxel size, used to express the kernel width in voxels.
    Returns a csr matrix (1 x prod(volshape)) with explicit zeros removed.
    """
    # FWHM = 2*sqrt(2*ln 2)*sigma ~= 2.35482*sigma; then convert to voxel units
    sigma_vox=fwhm/2.35482/voxmm
    dense_volume=np.reshape(np.array(sparsevals.todense()),volshape)
    blurred=ndimage.gaussian_filter(dense_volume,sigma=sigma_vox)
    result=sparse.csr_matrix(blurred.flatten())
    result.eliminate_zeros()
    return result
############################################################
############################################################
############################################################
# multiprocessing.map functions
# * only take a single iterable input
# * each requires certain externally defined READ-ONLY variables
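# (trying to avoid passing these as additional arguments because they are large
# and I'm worried they will take up additional memory in each subprocess.
# This relies on the "fork" start method, which is the default on Linux. It will likely
# fail on Windows, and on macOS with Python >= 3.8 (where "spawn" is the default),
# unless "fork" is explicitly selected.)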
# * Note these external variables can't be defined inside a function, so "main" has
# to be inside an "if" but not a "def main():"
#original:
#Mapping to endpoints and ChaCo took 4.252 seconds on 15 threads
#Mapping to conn took 2.511 seconds on 15 threads
#now with T.eliminate_zeros() in map_to_endpoints() also:
#Mapping to endpoints and ChaCo took 2.222 seconds on 15 threads
#Mapping to conn took 2.420 seconds on 15 threads
###########################################################
# make a per process s3_client
# Per-process boto3 S3 client; created by s3initialize() (presumably passed as
# a multiprocessing.Pool initializer) or lazily on the first download.
nemo_s3_client = None


def s3initialize():
    """Create this process's S3 client and store it in ``nemo_s3_client``."""
    global nemo_s3_client
    nemo_s3_client = boto3.client('s3')


def s3download(job):
    """Download one S3 object.

    job: (bucket, key, filename) tuple; the object ``s3://bucket/key`` is written
        to the local path ``filename``.
    """
    bucket, key, filename = job
    # create the client on demand if no initializer ran in this process
    if nemo_s3_client is None:
        s3initialize()
    nemo_s3_client.download_file(bucket,key,filename)
def save_lesion_chunk(whichchunk):
    """Multiprocessing worker: compute every subject's lesion hits within one chunk.

    Reads the READ-ONLY module globals chunkfile_fmt, tmpchunkfile_fmt, Lmask,
    chunkidx_flat, numsubj and do_cumulative_hits, which are set up by the main script.

    The chunk file ``chunkfile_fmt % whichchunk`` is sliced into ``numsubj``
    consecutive row blocks, each as tall as the number of voxels in this chunk
    (presumably one block of voxel x streamline rows per subject). For each
    block, the lesion values of the chunk's voxels are multiplied through it.
    The result is binarized (streamline hit or not) unless
    ``do_cumulative_hits``. The per-subject rows are stacked and saved
    uncompressed to ``tmpchunkfile_fmt % whichchunk``.

    Returns ``whichchunk``.
    """
    chunk_matrix=sparse.load_npz(chunkfile_fmt % (whichchunk))
    lesion_in_chunk=Lmask[chunkidx_flat==whichchunk]
    # edge chunks can hold fewer voxels than a full chunk, so use this chunk's own size
    nvox=len(lesion_in_chunk)
    per_subject=[]
    for subj in range(numsubj):
        rows=slice(subj*nvox,(subj+1)*nvox)
        hits=sparse.csr_matrix(lesion_in_chunk @ chunk_matrix[rows,:])
        if not do_cumulative_hits:
            # binarize: without this the ChaCo denominator (Asum) would need to change
            hits=hits>0
        hits.eliminate_zeros()
        per_subject.append(hits)
    sparse.save_npz(tmpchunkfile_fmt % (whichchunk),sparse.vstack(per_subject),compressed=False)
    return whichchunk
def map_to_endpoints(isubj):
    """Multiprocessing worker: project one subject's lesion hits onto streamline endpoints.

    Reads the READ-ONLY module globals endpointmat, numsubj, tidx, numtracks,
    numvoxels, do_weighted, T_allsubj, trackweights, Asum and tmpdir, which are
    set up by the main script.

    The two endpoints of each streamline come from rows ``isubj`` and
    ``isubj+numsubj`` of ``endpointmat``. Row ``isubj`` of ``T_allsubj`` is
    optionally weighted by ``trackweights``, mapped to those endpoint voxels and
    multiplied element-wise by row ``isubj`` of ``Asum``. Voxel 0 is then zeroed,
    and the 1 x numvoxels result is saved to ``<tmpdir>/chacovol_subj%05d.npz``.
    """
    endpt=endpointmat[(isubj,isubj+numsubj),:].flatten()
    #note: create this as a float32 so later mult properly SUMS columns instead of just logical
    B=sparse.csr_matrix((np.ones(tidx.shape,dtype=np.float32),(tidx,endpt)),shape=(numtracks,numvoxels))
    if do_weighted:
        chacovol=((T_allsubj[isubj,:]).multiply(trackweights[isubj,:]) @ B).multiply(Asum[isubj,:])
    else:
        chacovol=(T_allsubj[isubj,:] @ B).multiply(Asum[isubj,:])
    # Any endpoints for "voxel 0" are from spurious endpoints of "lost" streamlines,
    # so zero only that voxel. `chacovol[0]=0` on a 1 x N sparse row addresses the
    # whole row and wiped every voxel. .multiply() may also return a non-csr format.
    chacovol=sparse.csr_matrix(chacovol)
    chacovol.data[chacovol.indices==0]=0
    chacovol.eliminate_zeros()
    chacofile_subj=tmpdir+'/chacovol_subj%05d.npz' % (isubj)
    sparse.save_npz(chacofile_subj,chacovol,compressed=False)
def map_to_endpoints_numerator(isubj):
    """Multiprocessing worker: voxelwise (and parcellated) ChaCo numerator for one subject.

    Reads the READ-ONLY module globals endpointmat, numsubj, tidx, numtracks,
    numvoxels, T_allsubj, trackweights, tmpdir and Psparse_list, which are set up
    by the main script.

    Row ``isubj`` of ``T_allsubj`` is cast to float32 and optionally weighted by
    ``trackweights``. Streamlines with either endpoint in voxel 0 get weight 0.
    The weights are then summed onto both endpoint voxels and saved to
    ``<tmpdir>/chacovol_subj%05d.npz``. For every entry of ``Psparse_list``, the
    voxel row is multiplied by that entry's 'transform' and saved to
    ``<tmpdir>/chacovol_parc%05d_subj%05d.npz``.
    """
    endpoints=endpointmat[(isubj,isubj+numsubj),:]
    # a streamline with any endpoint in "voxel 0" is treated as lost
    lost=np.any(endpoints==0,axis=0)
    # float32 so the later product SUMS columns instead of acting as a logical OR
    ones=np.ones(tidx.shape,dtype=np.float32)
    B=sparse.csr_matrix((ones,(tidx,endpoints.flatten())),shape=(numtracks,numvoxels))

    T=T_allsubj[isubj,:].astype(np.float32)
    if trackweights is not None:
        T.data*=trackweights[isubj,T.indices]
    T.data[lost[T.indices]]=0

    chacovol=T @ B
    chacovol.eliminate_zeros()
    sparse.save_npz(tmpdir+'/chacovol_subj%05d.npz' % (isubj),chacovol,compressed=False)
    # TODO: the denominators (Asum / Asum_weighted) still need to account for tracklengths

    if not Psparse_list:
        return
    for iparc, Pdict in enumerate(Psparse_list):
        transform=Pdict['transform']
        if isinstance(transform,list):
            # per-subject transforms are stored as csc; convert for the product
            transform=transform[isubj].tocsr()
        chacovol_parc=chacovol @ transform
        chacovol_parc.eliminate_zeros()
        sparse.save_npz(tmpdir+'/chacovol_parc%05d_subj%05d.npz' % (iparc,isubj),chacovol_parc,compressed=False)
###########################################################
def map_to_endpoints_conn(isubj):
    """Multiprocessing worker: pairwise (endpoint-to-endpoint) ChaCo outputs for one subject.

    Builds a numvoxels x numvoxels upper-triangular matrix ``chacoconn``. Each
    lesioned streamline of subject ``isubj`` adds its weight at
    (min endpoint, max endpoint). Weights are row ``isubj`` of ``T_allsubj``,
    optionally times ``trackweights``; entries at duplicate coordinates are
    summed by the sparse constructor.

    From it, this function writes to ``tmpdir``:
      * chacovol_subj%05d.npz: per-voxel sum of streamline weights over both
        endpoints (self-connections counted once).
      * chacoconn_subj%05d.npz: the full voxel x voxel matrix (only if do_save_fullconn).
      * if do_compute_denom, the same quantities for the denominator built from
        ALL streamlines: chacoconn_denom_subj%05d.npz (only if do_save_fullconn)
        and chacovol_denom_subj%05d.npz.
      * for every entry of Psparse_list, parcellated versions
        (chacoconn_parc%05d_subj%05d.npz, chacovol_parc%05d_subj%05d.npz and,
        if do_compute_denom, the matching *_denom_* files).

    Streamlines with either endpoint in voxel 0 are treated as "lost" and get weight 0.
    """
    #externals: endpointmat, numsubj, T_allsubj, trackweights, numvoxels, tmpdir, do_save_fullconn, do_compute_denom, do_cumulative_hits, tracklengths, Psparse_list
    endpt=endpointmat[(isubj,isubj+numsubj),:]
    # order the two endpoints so every entry lands in the upper triangle (incl. diagonal)
    endpt1=endpt.min(axis=0)
    endpt2=endpt.max(axis=0)
    #any endpoints in "voxel 0" are from spurious endpoints for "lost" streamlines
    endpt_iszero=(endpt1==0) | (endpt2==0)
    #chacoconn=sparse.csr_matrix(((T_allsubj[isubj,:]>0).toarray().flatten(),(endpt1,endpt2)),shape=(numvoxels,numvoxels),dtype=np.float32)
    #note: need to cast to non-bool here otherwise the summing in sparse matrix build doesn't work!
    T=T_allsubj[isubj,:].astype(np.float32)
    if trackweights is not None:
        T.data*=trackweights[isubj,T.indices]
    T.data[endpt_iszero[T.indices]]=0
    #T.eliminate_zeros()
    chacoconn=sparse.csr_matrix((T.data,(endpt1[T.indices],endpt2[T.indices])),shape=(numvoxels,numvoxels),dtype=np.float32)
    #sparse.save_npz(tmpdir+'/chacoconnAconnsum_subj%05d.npz' % (isubj),Aconnsum,compressed=False)
    # per-voxel total: column sums (all entries with this voxel as 2nd endpoint, incl. diagonal)
    # plus strictly-upper row sums (entries with this voxel as 1st endpoint), so each
    # streamline is counted at both endpoints and self-connections only once
    chacovol=sparse.csr_matrix(chacoconn.sum(axis=0)+sparse.triu(chacoconn,k=1).sum(axis=1).T)
    #!!!!! chacovol[0]=0
    chacovol.eliminate_zeros()
    chacofile_subj=tmpdir+'/chacovol_subj%05d.npz' % (isubj)
    sparse.save_npz(chacofile_subj,chacovol,compressed=False)
    #might not want the full voxel x voxel connectivity matrix (probably don't!)
    if do_save_fullconn:
        chacofile_subj=tmpdir+'/chacoconn_subj%05d.npz' % (isubj)
        #sparse.save_npz(chacofile_subj,chacoconn.multiply(Aconnsum),compressed=False)
        sparse.save_npz(chacofile_subj,chacoconn,compressed=False)
    #chacoconn file is 40MB per subject (*420=16.8GB) for 375chunk lesion, 28MB per subject (*420=11.8GB) for the smallest lesion
    #takes about 2.7x as long as simple chacovol
    #we might want to use a PREcomputed Aconnsum denominator. Bigger file (12GB) but faster calculation
    #however computing it here is more flexible IF we want to compute
    if do_compute_denom:
        # denominator: every (non-lost) streamline, weighted the same way as the numerator
        if trackweights is None:
            denom_val=np.ones(endpt1.size,dtype=np.float32)
        else:
            denom_val=trackweights[isubj,:].copy()
        if do_cumulative_hits:
            #denominator should assume hits along ENTIRE length in this case
            denom_val*=tracklengths[isubj,:]
        denom_val[endpt_iszero]=0
        #only need to store denominator when numerator (T) is non-zero
        #actually no! since we use this for parcellation we will lose all the other voxels in the parcel and the parcellated ratios will all be ~1!
        #Aconnsum=sparse.csr_matrix((denom_val[T.indices],(endpt1[T.indices],endpt2[T.indices])),shape=(numvoxels,numvoxels),dtype=np.float32)
        Aconnsum=sparse.csr_matrix((denom_val,(endpt1,endpt2)),shape=(numvoxels,numvoxels),dtype=np.float32)
        Aconnsum.eliminate_zeros()
        if do_save_fullconn:
            chacofile_subj=tmpdir+'/chacoconn_denom_subj%05d.npz' % (isubj)
            sparse.save_npz(chacofile_subj,Aconnsum,compressed=False)
        #compute the full voxelwise denom here since it doesn't take that long relative to other
        #steps and we dont have to worry about precomputing every combination of flavors
        chacovol_denom=sparse.csr_matrix(Aconnsum.sum(axis=0)+sparse.triu(Aconnsum,k=1).sum(axis=1).T)
        chacovol_denom.eliminate_zeros()
        chacofile_subj=tmpdir+'/chacovol_denom_subj%05d.npz' % (isubj)
        sparse.save_npz(chacofile_subj,chacovol_denom,compressed=False)
    if Psparse_list:
        for iparc, Pdict in enumerate(Psparse_list):
            if isinstance(Pdict['transform'],list):
                # per-subject transforms are stored as csc; convert this subject's one to csr
                Psparse=Pdict['transform'][isubj].tocsr()
            else:
                Psparse=Pdict['transform']
            # region x region connectivity: P^T * C * P
            chacoconn_parc=Psparse.T.tocsr() @ chacoconn @ Psparse
            #Make parcellated/downsampled outputs upper triangular (keeping diagonal)
            chacoconn_parc=sparse.triu(chacoconn_parc,k=0)+sparse.tril(chacoconn_parc,k=-1).T
            #sparse.save_npz(chacofile_subj,chacoconn.multiply(Aconnsum),compressed=False)
            chacofile_subj=tmpdir+'/chacoconn_parc%05d_subj%05d.npz' % (iparc,isubj)
            sparse.save_npz(chacofile_subj,chacoconn_parc,compressed=False)
            #chacovol_parc=((T_allsubj[isubj,:]>0) @ (B @ Psparse)).multiply(Asum[isubj,:] @ Psparse))
            #kval=0 means keep self-self entries (pairwise diagonals) when computing regional scores
            #kval=1 means remove that diagonal before computing regional score
            #pairwise (chacoconn) outputs remain unchanged
            chacovol_keepdiag_kval=1 #exclude diag by default
            # Pdict must provide a 'keepdiag' key (KeyError otherwise)
            if Pdict['keepdiag']:
                chacovol_keepdiag_kval=0
            chacovol_parc=sparse.csr_matrix(sparse.triu(chacoconn_parc,k=chacovol_keepdiag_kval).sum(axis=0)+sparse.triu(chacoconn_parc,k=1).sum(axis=1).T)
            chacovol_parc.eliminate_zeros()
            chacofile_subj=tmpdir+'/chacovol_parc%05d_subj%05d.npz' % (iparc,isubj)
            sparse.save_npz(chacofile_subj,chacovol_parc,compressed=False)
            if do_compute_denom:
                # same parcellation steps applied to the all-streamline denominator
                Aconnsum_parc=Psparse.T.tocsr() @ Aconnsum @ Psparse
                Aconnsum_parc.eliminate_zeros()
                #Make parcellated/downsampled outputs upper triangular (keeping diagonal)
                Aconnsum_parc=sparse.triu(Aconnsum_parc,k=0)+sparse.tril(Aconnsum_parc,k=-1).T
                chacofile_subj=tmpdir+'/chacoconn_parc%05d_denom_subj%05d.npz' % (iparc,isubj)
                sparse.save_npz(chacofile_subj,Aconnsum_parc,compressed=False)
                chacovol_parc_denom=sparse.csr_matrix(sparse.triu(Aconnsum_parc,k=chacovol_keepdiag_kval).sum(axis=0)+sparse.triu(Aconnsum_parc,k=1).sum(axis=1).T)
                chacovol_parc_denom.eliminate_zeros()
                chacofile_subj=tmpdir+'/chacovol_parc%05d_denom_subj%05d.npz' % (iparc,isubj)
                sparse.save_npz(chacofile_subj,chacovol_parc_denom,compressed=False)
###########################################################
def parcellation_to_volume(parcdata, parcvol):
    """Paint per-parcel values back into a volume.

    parcdata: per-parcel values, either 1D (one value per parcel) or 2D with the
        parcels along either axis. When the parcel axis is longer than the
        number of parcels present (e.g. full cifti91k data), label v is taken
        from index v-1 along that axis. For 2D input the remaining axis is
        averaged.
    parcvol: integer-like label volume; 0 is background.

    Returns a float array shaped like ``parcvol`` (zeros outside parcels), or
    None (after printing a message) if the data does not match the parcellation.
    """
    parcdata=np.asarray(parcdata)
    is_1d=parcdata.ndim==1
    if is_1d:
        # treat a vector as a single column so the averaging below still works
        parcdata=parcdata[:,np.newaxis]
    parcmask=parcvol!=0
    uparc,uparc_idx=np.unique(parcvol[parcmask],return_inverse=True)
    if uparc.size == 0:
        # no labelled voxels: nothing to paint
        return np.zeros(parcvol.shape)
    nparc=len(uparc)
    maxparc=uparc.max()
    if parcdata.shape[0] == nparc:
        pass
    elif not is_1d and parcdata.shape[1] == nparc:
        parcdata=parcdata.T
    elif parcdata.shape[0] >= maxparc:
        #this happens if input is cifti91k (full 0-91282) and parcvol does not have all of those indices
        parcdata=parcdata[uparc.astype(np.uint32)-1,:]
    elif not is_1d and parcdata.shape[1] >= maxparc:
        #this happens if input is cifti91k (full 0-91282) and parcvol does not have all of those indices
        parcdata=parcdata[:,uparc.astype(np.uint32)-1].T
    else:
        print("Parcellated data dimensions do not match parcellation")
        return None
    newvol=np.zeros(parcvol.shape)
    newvol[parcmask]=np.mean(parcdata[uparc_idx],axis=1)
    return newvol
def make_triangular_matrix_symmetric(m):
    """Mirror a strictly one-sided triangular matrix into a symmetric one, in place.

    If only the upper triangle (excluding the diagonal) has nonzeros, it is
    added to the lower triangle, and vice versa. Matrices with nonzeros in both
    triangles, or in neither, are left untouched. Returns ``m``, which is also
    modified in place.
    """
    nonzero=m!=0
    upper_used=np.any(np.triu(nonzero,1))
    lower_used=np.any(np.tril(nonzero,-1))
    if upper_used == lower_used:
        return m
    if upper_used:
        m+=np.triu(m,1).T
    else:
        m+=np.tril(m,-1).T
    return m
def save_chaco_output(chaco_output, delete_files=True):
#externals: Psparse_list, NUMBER_OF_SUBJECTS_TO_COMPUTE, tmpdir, Asum, Aconnsum, do_debug, outputbase,
#for chaco_output in chaco_output_list:
#print(chaco_output)
chaco_allsubj=[]
chaco_denom_allsubj=[]
do_nonzero_denom=False
nonzero_denom_thresh=None
if 'only_nonzero_denom' in chaco_output and chaco_output['only_nonzero_denom']:
do_nonzero_denom=True
if 'nonzero_denom_thresh' in chaco_output:
nonzero_denom_thresh=chaco_output['nonzero_denom_thresh']
Psparse=None
output_reshape=chaco_output['reshape']
if chaco_output['parcelindex'] is not None:
Psparse=Psparse_list[chaco_output['parcelindex']]['transform']
starttime_accum=time.time()
#for isubj in range(numsubj):
for isubj in range(NUMBER_OF_SUBJECTS_TO_COMPUTE):
chacofile_subj=tmpdir+'/'+chaco_output['numerator'] % (isubj)
chacofile_subj_denom=None
chaco_numer=sparse.load_npz(chacofile_subj)
if isinstance(Psparse,list):
Ptmp=Psparse[isubj].tocsr()
else:
Ptmp=Psparse
if chaco_output['denominator'] == 'Asum':
if Ptmp is None:
chaco_denom=Asum[isubj,:]
else:
chaco_denom=Asum[isubj,:] @ Ptmp
#DON'T zero the denominator when numer is zero, because
#we need the original denominator intact for nemoSC
#chaco_denom = chaco_denom.multiply(chaco_numer>0)
chaco_denom.eliminate_zeros()
elif chaco_output['denominator'] == 'Aconnsum':
if Ptmp is None:
chaco_denom=Aconnsum[isubj]
else:
chaco_denom=Ptmp.T.tocsr() @ Aconnsum[isubj] @ Ptmp
#DON'T zero the denominator when numer is zero, because
#we need the original denominator intact for nemoSC
#chaco_denom = chaco_denom.multiply(chaco_numer>0)
chaco_denom.eliminate_zeros()
else:
chacofile_subj_denom=tmpdir+'/'+chaco_output['denominator'] % (isubj)
chaco_denom=sparse.load_npz(chacofile_subj_denom)
chaco_denom_allsubj.append(chaco_denom.copy())
chaco_denom.data=1/chaco_denom.data.astype(np.float32)
chaco_allsubj.append(chaco_numer.multiply(chaco_denom))
if delete_files:
os.remove(chacofile_subj)
if chacofile_subj_denom is not None:
os.remove(chacofile_subj_denom)
if do_debug:
print('Loading in %s took %s' % (chaco_output['name'],durationToString(time.time()-starttime_accum)))
# if do_nonzero_denom:
# in this mode, for each voxel/parcel/pairwise connection, we compute the chaco ratio (lesion streamlines / total streamlines),
# but IGNORE subjects that have 0 streamlines in the denominator for that location
# else:
# in standard mode, we compute the chaco ratio for each location (lesion streamlines/total streamlines) for each subject and simply average across subjects
# If a subject does not have any streamlines in a location (denom = total streamlines=0), their chaco ratio for that location = 0
# and this 0 is included in the average across all subjects
#generate an output that gives the fraction of subjects for which the denominator was non-zero
#this can be used to mask the chacomean output to exclude regions with inconsistent denominators
chacomean_denom_binfrac=None
# compute mean and stdev of chaco scores across all reference subjects
if chaco_allsubj[0].shape[0] == 1:
#stackable (for 1D chacovol)
chaco_allsubj=sparse.vstack(chaco_allsubj)
chaco_denom_allsubj=sparse.vstack(chaco_denom_allsubj)
if do_nonzero_denom:
chaconzd_numer=np.array(np.sum(chaco_allsubj,axis=0))
chaconzd_denom=np.array(np.sum(chaco_denom_allsubj>0,axis=0))
chaconzd_sqnumer=np.array(np.sum(chaco_allsubj.multiply(chaco_allsubj),axis=0))
chaconzd_mask=(chaconzd_numer>0) & (chaconzd_denom>0)
chaconzd_mean=np.zeros_like(chaconzd_numer)
chaconzd_sqmean=np.zeros_like(chaconzd_numer)
chaconzd_mean[chaconzd_mask]=chaconzd_numer[chaconzd_mask]/chaconzd_denom[chaconzd_mask]
chaconzd_sqmean[chaconzd_mask]=chaconzd_sqnumer[chaconzd_mask]/chaconzd_denom[chaconzd_mask]
chaconzd_std=np.sqrt(np.clip(chaconzd_sqmean-chaconzd_mean**2,0,None))
chacomean_denom_binfrac=chaconzd_denom/chaco_allsubj.shape[0]
chacomean=chaconzd_mean
chacostd=chaconzd_std
else:
chacomean=np.array(np.mean(chaco_allsubj,axis=0))
chacostd=np.sqrt(np.clip(np.array(np.mean(chaco_allsubj.multiply(chaco_allsubj),axis=0) - chacomean**2),0,None))
else:
#non-stackable (for list of 2D chacoconn)
chacomean=0
chacosqmean=0
for ch in chaco_allsubj:
chacomean+=ch
chacosqmean+=ch.multiply(ch)
if do_nonzero_denom:
denom=0
for chd in chaco_denom_allsubj:
denom+=(chd>0).astype(np.float32)
chacomean_denom_binfrac=denom/len(chaco_denom_allsubj)
#need to invert the denom to use .multiply for element-wise division
denom.data=1.0/denom.data
chacomean=chacomean.multiply(denom)
chacosqmean=chacosqmean.multiply(denom)
else:
chacomean/=len(chaco_allsubj)
chacosqmean/=len(chaco_allsubj)
chacostd=chacosqmean - chacomean.multiply(chacomean)
chacostd[chacostd<0]=0
chacostd.eliminate_zeros()
chacostd=np.sqrt(chacostd)
#this sqrt can be negative sometimes!
#assuming it's just a numerical precision thing and set it to 0
if sparse.issparse(chacostd):
chacostd.data[np.isnan(chacostd.data)]=0
else:
chacostd[np.isnan(chacostd)]=0
#threshold the chacomean by denom_binfrac if threshold is provided:
# entries where the fraction of reference subjects with a nonzero denominator
# does not exceed nonzero_denom_thresh are zeroed in both the mean and the stdev.
# The 1D (chacovol) path produces dense numpy arrays, which have no .multiply
# method, so only sparse results use .multiply and dense ones use np.multiply.
if chacomean_denom_binfrac is not None and nonzero_denom_thresh is not None and nonzero_denom_thresh>0:
    denom_keepmask=chacomean_denom_binfrac>nonzero_denom_thresh
    if sparse.issparse(chacomean):
        chacomean=chacomean.multiply(denom_keepmask)
    else:
        chacomean=np.multiply(chacomean,denom_keepmask.toarray() if sparse.issparse(denom_keepmask) else denom_keepmask)
    if sparse.issparse(chacostd):
        chacostd=chacostd.multiply(denom_keepmask)
    else:
        chacostd=np.multiply(chacostd,denom_keepmask.toarray() if sparse.issparse(denom_keepmask) else denom_keepmask)
# Save the per-subject reference results, then the summary statistics: as
# pickles when output_reshape is None, otherwise as NIfTI volumes reshaped to
# output_reshape's grid. Files are written through context managers so the
# handles are closed deterministically instead of being left to the GC.
outfile_prefix=outputbase+'_'+chaco_output['name']
outfile_pickle=outfile_prefix+"_allref.pkl"
with open(outfile_pickle,"wb") as fid:
    pickle.dump(chaco_allsubj,fid)
outfile_pickle=outfile_prefix+"_allref_denom.pkl"
with open(outfile_pickle,"wb") as fid:
    pickle.dump(chaco_denom_allsubj,fid)
# (suffix, data) pairs for the summary outputs; denomfrac only exists when
# the nonzero-denominator mode produced it.
summary_outputs=[('mean',chacomean),('stdev',chacostd)]
if chacomean_denom_binfrac is not None:
    summary_outputs.append(('denomfrac',chacomean_denom_binfrac))
if output_reshape is None:
    for outsuffix,outdata in summary_outputs:
        with open(outfile_prefix+'_'+outsuffix+'.pkl',"wb") as fid:
            pickle.dump(outdata,fid)
else:
    for outsuffix,outdata in summary_outputs:
        outimg=nib.Nifti1Image(np.reshape(np.array(outdata),output_reshape.shape),affine=output_reshape.affine, header=output_reshape.header)
        nib.save(outimg,outputbase+'_%s_%s.nii.gz' % (chaco_output['name'],outsuffix))
if do_debug:
    print('Saving %s took %s' % (chaco_output['name'],durationToString(time.time()-starttime_accum)))
############################################################
############################################################
############################################################
# main command-line interface
if __name__ == "__main__":
args=argument_parse(sys.argv[1:])
# A continuous-valued lesion implies cumulative track-hit counting.
if args.continuous_value:
    args.cumulative=True

# Unpack the parsed options into the module-level names used by the rest of the script.
lesionfile, outputbase = args.lesion, args.outputbase
chunklistfile, chunkdir = args.chunklist, args.chunkdir
refimgfile = args.refvol
endpointfile, endpointmaskfile = args.endpoints, args.endpointsmask
asumfile, asumweightedfile = args.asum, args.asum_weighted
asumcumfile, asumweightedcumfile = args.asum_cumulative, args.asum_weighted_cumulative
trackweightfile = args.trackweights
do_weighted, do_smooth = args.weighted, args.smoothed
smoothing_fwhm, smoothing_mode = args.smoothfwhm, args.smoothmode
s3nemoroot = args.s3nemoroot
parcelfiles, new_resolution = args.parcelvol, args.resolution
tracklengthfile = args.tracklengths
do_cumulative_hits, do_pairwise, do_continuous = args.cumulative, args.pairwise, args.continuous_value
do_debug, debug_subjcount = args.debug, args.subjcount
tracking_algorithm = args.tracking_algorithm
do_only_include_nonzero_subjects = args.only_nonzero_denom
numthread_arg = args.num_threads
if numthread_arg is not None:
    NUM_THREADS = numthread_arg

# Echo the options that were actually given, in command-line form.
print("Executed with the following inputs:")
for k,v in vars(args).items():
    # None = option not given; False = store_true flag not given
    if v is None or v is False:
        continue
    if v is True:
        print("--%s" % (k))
    elif isinstance(v,list):
        for vv in v:
            print("--%s" % (k), vv)
    else:
        print("--%s" % (k), v)
print("")
# Run-wide flags; do_save_fullvol/do_save_fullconn may be switched on later
# when a <=1mm output resolution is requested.
do_force_redownload = False
do_download_nemofiles = False
do_save_fullvol = False
do_save_fullconn = False
do_compute_denom = True

# Pick the Asum file matching the weighting/cumulative mode; the unweighted,
# non-cumulative case keeps the --asum value unchanged.
if do_weighted:
    asumfile = asumweightedcumfile if do_cumulative_hits else asumweightedfile
elif do_cumulative_hits:
    asumfile = asumcumfile

# Split an S3 root such as "s3://bucket/some/prefix/" into the bucket ("bucket")
# and a key prefix ending in exactly one "/" ("some/prefix/"). Only a leading
# scheme is stripped, and trailing slashes are normalised so a root given with a
# trailing "/" no longer yields "prefix//" keys.
if s3nemoroot:
    do_download_nemofiles = True
    if s3nemoroot.lower().startswith("s3://"):
        s3nemoroot=s3nemoroot[len("s3://"):]
    s3nemoroot_bucket,_,s3nemoroot_prefix=s3nemoroot.partition("/")
    s3nemoroot_prefix=s3nemoroot_prefix.rstrip("/")
    if s3nemoroot_prefix:
        s3nemoroot_prefix+="/"
if do_download_nemofiles:
    starttime_download_nemofiles=time.time()
    # Reference files needed for this run; the optional ones depend on the selected mode.
    nemofiles_to_download=[asumfile]
    if do_weighted:
        nemofiles_to_download.append(trackweightfile)
    if do_cumulative_hits:
        nemofiles_to_download.append(tracklengthfile)
    if endpointmaskfile:
        nemofiles_to_download.append(endpointmaskfile)
    nemofiles_to_download.extend([endpointfile,chunklistfile,refimgfile])
    #check if we've already downloaded them (might be a multi-file run)
    if not do_force_redownload:
        nemofiles_to_download=[f for f in nemofiles_to_download if not os.path.exists(f)]
    if len(nemofiles_to_download) > 0:
        print('Downloading NeMo data files', end='', flush=True)
        num_cpu=get_available_cpus()
        # Leave one CPU free, but always use at least one worker (Pool(0) raises
        # ValueError on single-CPU hosts) and never more workers than files.
        multiproc_cores=max(1,min(num_cpu-1,len(nemofiles_to_download)))
        P=multiprocessing.Pool(multiproc_cores, s3initialize)
        # presumably (bucket, S3 key under the root prefix, local path) -- confirm against s3download
        jobs = [(s3nemoroot_bucket,s3nemoroot_prefix+k.split("/")[-1],k) for k in nemofiles_to_download]
        try:
            P.map(s3download,jobs)
        except Exception as e:
            print('Download failed:',e)
            P.terminate()
            P.join()
            sys.exit(1)
        P.close()
        P.join()
        print(' took %.3f seconds' % (time.time()-starttime_download_nemofiles))
# Parse the smoothing FWHM. Smoothing is disabled when the value cannot be
# converted to float (e.g. None -> TypeError, non-numeric text -> ValueError)
# or when it is not a finite positive number.
try:
    smoothing_fwhm=float(smoothing_fwhm)
except (TypeError,ValueError):
    do_smooth=False
    smoothing_fwhm=0.
if not np.isfinite(smoothing_fwhm) or smoothing_fwhm <= 0:
    do_smooth=False
# Split the output base path into its directory and file-name parts.
outputbase_path=Path(outputbase)
outputdir=outputbase_path.parent.as_posix()
outputbase_file=outputbase_path.name
print('Lesion file: %s' % (lesionfile))
print('Output basename: %s' % (outputbase))
print('Track weighting: ', do_weighted)
print('Cumulative track hits: ', do_cumulative_hits)
print('Continuous-valued lesion volume: ', do_continuous)
print('Output pairwise connectivity: ', do_pairwise)
print('Only include non-zero denom subjects: ', do_only_include_nonzero_subjects)
starttime=time.time()
# Load the chunk index: per-voxel chunk assignment plus reference-set metadata.
chunklist=np.load(chunklistfile)
volshape=chunklist['volshape']
chunksize=chunklist['chunksize']
chunkidx_flat=chunklist['chunkidx_flat']
subjects=chunklist['subjects']
numtracks=chunklist['numtracks']
unique_chunks=chunklist['unique_chunks']
refimg=nib.load(refimgfile)
# Load the lesion and validate it against the accepted volume shapes
# (presumably the FSL and SPM 1mm template grids -- confirm in checkVolumeShape).
Limg=nib.load(lesionfile)
expected_shape=(182,218,182)
expected_shape_spm=(181,217,181)
Limg = checkVolumeShape(Limg, refimg, lesionfile.split("/")[-1], expected_shape, expected_shape_spm)
Ldata=Limg.get_fdata()
# Zero out NaN and +/-Inf voxels: both would otherwise break the max-based
# normalisation below (an Inf maximum turns every voxel into 0 or NaN).
Ldata[~np.isfinite(Ldata)]=0
# voxel size in mm along the first axis, from the length of the affine's first column
voxmm=np.sqrt(Limg.affine[:3,0].dot(Limg.affine[:3,0]))
Ldata_max=np.max(np.abs(Ldata))
if Ldata_max == 0:
    raise ValueError('Input lesion mask is all zeros!')
# Scale into [-1, 1] so a continuous-valued lesion keeps its relative weights.
Ldata=Ldata/Ldata_max
if do_continuous:
    Ldata=Ldata.astype(np.float32)
else:
    #remember to change nemo_save_average_glassbrain.py --binarize option if we change this!
    Ldata=Ldata!=0
Lmask=Ldata.flatten()
##################
# Psparse_list collects one dict per extra output space (downsampled resolution
# or parcellation) with keys 'transform', 'reshape', 'voxmm', 'name',
# 'pairwise', 'displayvol' and 'keepdiag'.
Psparse_list=[]
origvoxmm=1
origres_name="res%dmm" % (origvoxmm)
# Each --resolution entry has the form "<mm>[=<name>][?nopairwise][?keepdiag]".
# When any are given, volshape is replaced by the reference image shape.
if new_resolution:
volshape=refimg.shape
for r in new_resolution:
r_pairwise=do_pairwise
r_keepdiag=False
if r.find("?") >= 0:
#handle ?nopairwise and ?keepdiag options
r_opts=r.split("?")[1:]
r=r.split("?")[0]
if "nopairwise" in r_opts:
r_pairwise=False
if "keepdiag" in r_opts:
r_keepdiag=True
# Voxel size is rounded to whole mm; without "=<name>" the name is "res<N>mm".
# NOTE(review): more than one "=" in the spec makes this unpacking raise ValueError.
if r.find("=") >= 0:
[r,rname]=r.split("=")
newvoxmm=round(abs(float(r)))
else:
newvoxmm=round(abs(float(r)))
rname="res%dmm" % (newvoxmm)
r_pairwisestr=""
if r_pairwise:
r_pairwisestr=" with pair-wise chacoconn"
if r_keepdiag:
r_pairwisestr+=", including diagonal in voxel-wise chacovol"
# A <=1mm resolution adds no downsampling transform: it turns on the
# full-resolution outputs instead, and a non-empty "=<name>" renames them
# (an empty name, e.g. "1=", keeps the default).
if newvoxmm <= 1:
do_save_fullvol=True
do_save_fullconn=r_pairwise
if rname:
origres_name=rname
print('Output will include resolution %.1fmm (volume dimensions = %dx%dx%d)%s.' % (newvoxmm,volshape[0],volshape[1],volshape[2],r_pairwisestr))
continue
# Coarser resolutions: presumably a sparse voxel-to-coarse-voxel mapping plus the
# new grid shape and reference image -- confirm against createSparseDownsampleParcellation.
Psparse, newvolshape, newrefimg = createSparseDownsampleParcellation(newvoxmm, origvoxmm, volshape, refimg)
Psparse_list.append({'transform': Psparse, 'reshape': newrefimg, 'voxmm': newvoxmm, 'name': rname, 'pairwise': r_pairwise, 'displayvol': None, 'keepdiag': r_keepdiag})
print('Output will include resolution %.1fmm (volume dimensions = %dx%dx%d)%s.' % (newvoxmm,newvolshape[0],newvolshape[1],newvolshape[2],r_pairwisestr))
# Each --parcelvol entry has the form
# "<file>[=<name>][?nopairwise][?keepdiag][?displayvol=<file>][?numroi=<int>]".
if parcelfiles:
starttime_loadparc=time.time()
for p in parcelfiles:
p_pairwise=do_pairwise
p_keepdiag=False
p_displayvol=None
p_numroi=None
if p.find("?") >= 0:
#handle ?nopairwise, ?keepdiag, ?displayvol= and ?numroi= options
p_opts=p.split("?")[1:]
p=p.split("?")[0]
if "nopairwise" in p_opts:
p_pairwise=False
if "keepdiag" in p_opts:
p_keepdiag=True
# NOTE(review): split("=")[1] truncates a displayvol path that itself contains "=".
if any([x.startswith("displayvol=") for x in p_opts]):
displayvolfile=[x.split("=")[1] for x in p_opts if x.startswith("displayvol=")][0]
p_displayvol=nib.load(displayvolfile)
# A non-integer numroi value is silently ignored (falls back to None).
if any([x.startswith("numroi=") for x in p_opts]):
p_numroi_str=[x.split("=")[1] for x in p_opts if x.startswith("numroi=")][0]
try:
p_numroi=int(p_numroi_str)
except ValueError:
p_numroi=None
# Without "=<name>", the output name is "parc" plus the current length of
# Psparse_list, which also counts any resolution entries added earlier.
if p.find("=") >= 0:
[pfile,pname]=p.split("=")
else:
pfile=p
pname="parc%05d" % (len(Psparse_list))
Psparse=loadParcellation(filename=pfile, numroi=p_numroi, refimg=refimg, expected_shape=expected_shape, expected_shape_spm=expected_shape_spm)
Psparse_list.append({'transform': Psparse, 'reshape': None, 'voxmm': None, 'name': pname, 'pairwise': p_pairwise, 'displayvol': p_displayvol, 'keepdiag': p_keepdiag})
p_pairwisestr=""
if p_pairwise:
p_pairwisestr=" with pair-wise chacoconn"
if p_keepdiag:
p_pairwisestr+=", include diagonal in region-wise chacovol"
# loadParcellation returns either a list (subject-specific parcellation) or a
# single matrix; in both cases shape[1] is reported as the parcel count.
if isinstance(Psparse,list):
numroi=Psparse[0].shape[1]
print('Output will include subject-specific parcellation %s (%s, total parcels = %d)%s.' % (pname,pfile.split("/")[-1],numroi,p_pairwisestr))
else:
numroi=Psparse.shape[1]
print('Output will include parcellation %s (%s, total parcels = %d)%s.' % (pname,pfile.split("/")[-1],numroi,p_pairwisestr))
print('Loading parcellations took %s' % (durationToString(time.time()-starttime_loadparc)))
# Notes on the storage size of subject-specific parcellations:
#if parcelfiles_subject_specific:
# for p in parcelfiles_subject_specific:
#250MB for 20 subjects as a 20x7M sparse matrix. would be 5GB for 420 subjects. only 48MB for 20subj compressed, so 1GB
#120MB for 50subj compressed
#130MB for 50subj UNCOMPRESSED when we masked by endpoints! = 1GB for 420 subjects
#Pdata=
# Downstream code receives None (not an empty list) when no extra outputs were requested.
if len(Psparse_list)==0:
Psparse_list=None
##################
# Find the chunk files touched by the lesion. Chunks absent from unique_chunks
# lie outside the reference white-matter volume and have no chunk file to load.
lesion_voxel_mask=Lmask!=0
chunks_in_lesion=np.unique(chunkidx_flat[lesion_voxel_mask])
print('Total voxels in lesion mask: %d' % (np.sum(lesion_voxel_mask)))
print('Total chunks in lesion mask: %d' % (len(chunks_in_lesion)))
missing_chunks=set(chunks_in_lesion)-set(unique_chunks)
if len(missing_chunks) > 0:
    # np.isin keeps chunks_to_load a sorted ndarray, consistent with the branch
    # below (it used to be a list built from a set, in arbitrary order).
    chunks_to_load=chunks_in_lesion[np.isin(chunks_in_lesion,unique_chunks)]
    print('Lesion includes %d chunks outside reference white-matter volume' % (len(missing_chunks)))
    print('Total white-matter chunks in lesion mask: %d' % (len(chunks_to_load)))
else:
    chunks_to_load=chunks_in_lesion
# Report the total on-disk size of the chunk files, in GB once it reaches 1 GiB.
totalchunkbytes=np.sum(chunklist['chunkfilesize'][chunks_to_load])
if totalchunkbytes >= 1024*1024*1024:
    totalchunkbytes_string='%.2f GB' % (totalchunkbytes/(1024*1024*1024))
else:
    totalchunkbytes_string='%.2f MB' % (totalchunkbytes/(1024*1024))
print('Total size for all %d chunk files: %s' % (len(chunks_to_load),totalchunkbytes_string))
chunkfile_fmt=chunkdir+'/chunk%05d.npz'
os.makedirs(chunkdir,exist_ok=True)
if do_download_nemofiles:
chunkfiles_to_download=[chunkdir+'/chunk%05d.npz' % (x) for x in chunks_to_load]
if do_force_redownload:
totalchunkbytes_download_string=totalchunkbytes_string
else:
chunks_to_download=[i for i,f in zip(chunks_to_load,chunkfiles_to_download) if not os.path.exists(f)]
chunkfiles_to_download=[chunkdir+'/chunk%05d.npz' % (x) for x in chunks_to_download]
totalchunkbytes_download=np.sum(chunklist['chunkfilesize'][chunks_to_download])
totalchunkbytes_download_string=""
if totalchunkbytes_download >= 1024*1024*1024:
totalchunkbytes_download_string='%.2f GB' % (totalchunkbytes_download/(1024*1024*1024))
else:
totalchunkbytes_download_string='%.2f MB' % (totalchunkbytes_download/(1024*1024))
if len(chunkfiles_to_download) > 0:
print('Downloading %d chunks (%s)' % (len(chunkfiles_to_download), totalchunkbytes_download_string), end='', flush=True)
starttime_download_chunks=time.time()
num_cpu=get_available_cpus()
multiproc_cores=num_cpu-1
P=multiprocessing.Pool(multiproc_cores, s3initialize)