Skip to content

Commit 9336dee

Browse files
committed
new split and merge interfaces
1 parent 0de90cf commit 9336dee

File tree

3 files changed

+233
-68
lines changed

3 files changed

+233
-68
lines changed

nipype/algorithms/misc.py

Lines changed: 230 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _list_outputs(self):
215215
self._gen_output_filename(fname))
216216
return outputs
217217

218+
218219
class CreateNiftiInputSpec(BaseInterfaceInputSpec):
219220
data_file = File(exists=True, mandatory=True, desc="ANALYZE img file")
220221
header_file = File(
@@ -421,7 +422,7 @@ class Matlab2CSV(BaseInterface):
421422
Example
422423
-------
423424
424-
>>> import nipype.algorithms.misc as misc
425+
>>> from nipype.algorithms import misc
425426
>>> mat2csv = misc.Matlab2CSV()
426427
>>> mat2csv.inputs.in_file = 'cmatrix.mat'
427428
>>> mat2csv.run() # doctest: +SKIP
@@ -615,7 +616,7 @@ class MergeCSVFiles(BaseInterface):
615616
Example
616617
-------
617618
618-
>>> import nipype.algorithms.misc as misc
619+
>>> from nipype.algorithms import misc
619620
>>> mat2csv = misc.MergeCSVFiles()
620621
>>> mat2csv.inputs.in_files = ['degree.mat','clustering.mat']
621622
>>> mat2csv.inputs.column_headings = ['degree','clustering']
@@ -744,7 +745,7 @@ class AddCSVColumn(BaseInterface):
744745
Example
745746
-------
746747
747-
>>> import nipype.algorithms.misc as misc
748+
>>> from nipype.algorithms import misc
748749
>>> addcol = misc.AddCSVColumn()
749750
>>> addcol.inputs.in_file = 'degree.csv'
750751
>>> addcol.inputs.extra_column_heading = 'group'
@@ -797,9 +798,11 @@ def __setattr__(self, key, value):
797798
self._outputs[key] = value
798799
super(AddCSVRowInputSpec, self).__setattr__(key, value)
799800

801+
800802
class AddCSVRowOutputSpec(TraitedSpec):
801803
csv_file = File(desc='Output CSV file containing rows ')
802804

805+
803806
class AddCSVRow(BaseInterface):
804807
"""Simple interface to add an extra row to a csv file
805808
@@ -814,7 +817,7 @@ class AddCSVRow(BaseInterface):
814817
Example
815818
-------
816819
817-
>>> import nipype.algorithms.misc as misc
820+
>>> from nipype.algorithms import misc
818821
>>> addrow = misc.AddCSVRow()
819822
>>> addrow.inputs.in_file = 'scores.csv'
820823
>>> addrow.inputs.si = 0.74
@@ -919,7 +922,7 @@ class CalculateNormalizedMoments(BaseInterface):
919922
Example
920923
-------
921924
922-
>>> import nipype.algorithms.misc as misc
925+
>>> from nipype.algorithms import misc
923926
>>> skew = misc.CalculateNormalizedMoments()
924927
>>> skew.inputs.moment = 3
925928
>>> skew.inputs.timeseries_file = 'timeseries.txt'
@@ -956,12 +959,116 @@ def calc_moments(timeseries_file, moment):
956959
return np.where(zero, 0, m3 / m2**(moment/2.0))
957960

958961

962+
class SplitROIsInputSpec(TraitedSpec):
963+
in_file = File(exists=True, mandatory=True,
964+
desc='file to be splitted')
965+
in_mask = File(exists=True, desc='only process files inside mask')
966+
roi_size = traits.Tuple(traits.Int, traits.Int, traits.Int,
967+
desc='desired ROI size')
968+
969+
970+
class SplitROIsOutputSpec(TraitedSpec):
971+
out_files = OutputMultiPath(File(exists=True),
972+
desc='the resulting ROIs')
973+
out_masks = OutputMultiPath(File(exists=True),
974+
desc='a mask indicating valid values')
975+
out_index = OutputMultiPath(File(exists=True),
976+
desc='arrays keeping original locations')
977+
978+
979+
class SplitROIs(BaseInterface):
980+
"""
981+
Splits a 3D image in small chunks to enable parallel processing.
982+
ROIs keep time series structure in 4D images.
983+
984+
Example
985+
-------
986+
987+
>>> from nipype.algorithms import misc
988+
>>> rois = misc.SplitROIs()
989+
>>> rois.inputs.in_file = 'diffusion.nii'
990+
>>> rois.inputs.in_mask = 'mask.nii'
991+
>>> rois.run() # doctest: +SKIP
992+
993+
"""
994+
input_spec = SplitROIsInputSpec
995+
output_spec = SplitROIsOutputSpec
996+
997+
def _run_interface(self, runtime):
998+
mask = None
999+
roisize = None
1000+
self._outnames = {}
1001+
1002+
if isdefined(self.inputs.in_mask):
1003+
mask = self.inputs.in_mask
1004+
if isdefined(self.inputs.roi_size):
1005+
roisize = self.inputs.roi_size
1006+
1007+
res = split_rois(self.inputs.in_file,
1008+
mask, roisize)
1009+
self._outnames['out_files'] = res[0]
1010+
self._outnames['out_masks'] = res[1]
1011+
self._outnames['out_index'] = res[2]
1012+
return runtime
1013+
1014+
def _list_outputs(self):
1015+
outputs = self.output_spec().get()
1016+
for k, v in self._outnames.iteritems():
1017+
outputs[k] = v
1018+
return outputs
1019+
1020+
1021+
class MergeROIsInputSpec(TraitedSpec):
1022+
in_files = InputMultiPath(File(exists=True, mandatory=True,
1023+
desc='files to be re-merged'))
1024+
in_index = InputMultiPath(File(exists=True, mandatory=True),
1025+
desc='array keeping original locations')
1026+
in_reference = File(exists=True, desc='reference file')
1027+
1028+
1029+
class MergeROIsOutputSpec(TraitedSpec):
1030+
merged_file = File(exists=True, desc='the recomposed file')
1031+
1032+
1033+
class MergeROIs(BaseInterface):
1034+
"""
1035+
Splits a 3D image in small chunks to enable parallel processing.
1036+
ROIs keep time series structure in 4D images.
1037+
1038+
Example
1039+
-------
1040+
1041+
>>> from nipype.algorithms import misc
1042+
>>> rois = misc.MergeROIs()
1043+
>>> rois.inputs.in_files = ['roi%02d.nii' % i for i in xrange(1, 6)]
1044+
>>> rois.inputs.in_reference = 'mask.nii'
1045+
>>> rois.inputs.in_index = ['roi%02d_idx.npz' % i for i in xrange(1, 6)]
1046+
>>> rois.run() # doctest: +SKIP
1047+
1048+
"""
1049+
input_spec = MergeROIsInputSpec
1050+
output_spec = MergeROIsOutputSpec
1051+
1052+
def _run_interface(self, runtime):
1053+
res = merge_rois(self.inputs.in_files,
1054+
self.inputs.in_index,
1055+
self.inputs.in_reference)
1056+
self._merged = res
1057+
return runtime
1058+
1059+
def _list_outputs(self):
1060+
outputs = self.output_spec().get()
1061+
outputs['merged_file'] = self._merged
1062+
return outputs
1063+
1064+
9591065
class NormalizeProbabilityMapSetInputSpec(TraitedSpec):
9601066
in_files = InputMultiPath(File(exists=True, mandatory=True,
9611067
desc='The tpms to be normalized') )
9621068
in_mask = File(exists=True,
9631069
desc='Masked voxels must sum up 1.0, 0.0 otherwise.')
9641070

1071+
9651072
class NormalizeProbabilityMapSetOutputSpec(TraitedSpec):
9661073
out_files = OutputMultiPath(File(exists=True),
9671074
desc="normalized maps")
@@ -976,7 +1083,7 @@ class NormalizeProbabilityMapSet(BaseInterface):
9761083
Example
9771084
-------
9781085
979-
>>> import nipype.algorithms.misc as misc
1086+
>>> from nipype.algorithms import misc
9801087
>>> normalize = misc.NormalizeProbabilityMapSet()
9811088
>>> normalize.inputs.in_files = [ 'tpm_00.nii.gz', 'tpm_01.nii.gz', 'tpm_02.nii.gz' ]
9821089
>>> normalize.inputs.in_mask = 'tpms_msk.nii.gz'
@@ -1059,6 +1166,123 @@ def normalize_tpms( in_files, in_mask=None, out_files=[] ):
10591166
return out_files
10601167

10611168

1169+
def split_rois(in_file, mask=None, roishape=None):
1170+
"""
1171+
Splits an image in ROIs for parallel processing
1172+
"""
1173+
import nibabel as nb
1174+
import numpy as np
1175+
from math import sqrt, ceil
1176+
import os.path as op
1177+
1178+
if roishape is None:
1179+
roishape = (10, 10, 1)
1180+
1181+
im = nb.load(in_file)
1182+
imshape = im.get_shape()
1183+
dshape = imshape[:3]
1184+
nvols = imshape[-1]
1185+
roisize = roishape[0] * roishape[1] * roishape[2]
1186+
droishape = (roishape[0], roishape[1], roishape[2], nvols)
1187+
1188+
if mask is not None:
1189+
mask = nb.load(mask).get_data()
1190+
mask[mask > 0] = 1
1191+
mask[mask < 1] = 0
1192+
else:
1193+
mask = np.ones(dshape)
1194+
1195+
mask = mask.reshape(-1).astype(np.uint8)
1196+
nzels = np.nonzero(mask)
1197+
els = np.sum(mask)
1198+
nrois = int(ceil(els/roisize))
1199+
1200+
data = im.get_data().reshape((mask.size, -1))
1201+
data = np.squeeze(data.take(nzels, axis=0))
1202+
nvols = data.shape[-1]
1203+
1204+
roidefname = op.abspath('onesmask.nii.gz')
1205+
nb.Nifti1Image(np.ones(roishape, dtype=np.uint8), None,
1206+
None).to_filename(roidefname)
1207+
1208+
out_files = []
1209+
out_mask = []
1210+
out_idxs = []
1211+
1212+
for i in xrange(nrois):
1213+
first = i * roisize
1214+
last = (i+1) * roisize
1215+
fill = 0
1216+
1217+
if last > els:
1218+
fill = last - els
1219+
last = els
1220+
1221+
droi = data[first:last, ...]
1222+
iname = op.abspath('roi%010d_idx' % i)
1223+
out_idxs.append(iname+'.npz')
1224+
np.savez(iname, (nzels[0][first:last],))
1225+
1226+
if fill > 0:
1227+
droi = np.vstack((droi, np.zeros((fill, nvols), dtype=np.float32)))
1228+
partialmsk = np.ones((roisize,), dtype=np.uint8)
1229+
partialmsk[-fill:] = 0
1230+
partname = op.abspath('partialmask.nii.gz')
1231+
nb.Nifti1Image(partialmsk.reshape(roishape), None,
1232+
None).to_filename(partname)
1233+
out_mask.append(partname)
1234+
else:
1235+
out_mask.append(roidefname)
1236+
1237+
fname = op.abspath('roi%010d.nii.gz' % i)
1238+
nb.Nifti1Image(droi.reshape(droishape),
1239+
None, None).to_filename(fname)
1240+
out_files.append(fname)
1241+
return out_files, out_mask, out_idxs
1242+
1243+
1244+
def merge_rois(in_files, in_idxs, in_ref,
1245+
dtype=None, out_file=None):
1246+
"""
1247+
Re-builds an image resulting from a parallelized processing
1248+
"""
1249+
import nibabel as nb
1250+
import numpy as np
1251+
import os.path as op
1252+
1253+
if out_file is None:
1254+
out_file = op.abspath('merged.nii.gz')
1255+
1256+
ref = nb.load(in_ref)
1257+
aff = ref.get_affine()
1258+
hdr = ref.get_header().copy()
1259+
rsh = ref.get_shape()
1260+
del ref
1261+
npix = rsh[0] * rsh[1] * rsh[2]
1262+
ndirs = nb.load(in_files[0]).get_shape()[-1]
1263+
newshape = (rsh[0], rsh[1], rsh[2], ndirs)
1264+
data = np.zeros((npix, ndirs), dtype=dtype)
1265+
for cname, iname in zip(in_files, in_idxs):
1266+
with np.load(iname) as f:
1267+
idxs = np.squeeze(f['arr_0'])
1268+
cdata = nb.load(cname).get_data().reshape(-1, ndirs)
1269+
nels = len(idxs)
1270+
idata = (idxs, )
1271+
data[idata, ...] = cdata[0:nels, ...]
1272+
1273+
if dtype is None:
1274+
dtype = np.float32
1275+
1276+
hdr.set_data_dtype(dtype)
1277+
hdr.set_xyzt_units('mm', 'sec')
1278+
hdr.set_data_shape(newshape)
1279+
nb.Nifti1Image(data.reshape(newshape).astype(dtype),
1280+
aff, hdr).to_filename(out_file)
1281+
1282+
return out_file
1283+
1284+
1285+
10621286
# Deprecated interfaces ---------------------------------------------------------
10631287
class Distance( nam.Distance ):
10641288
"""Calculates distance between two volumes.

nipype/workflows/dmri/fsl/dti.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -121,63 +121,3 @@ def merge_and_mean(name='mm'):
121121
(mean, outputnode, [('out_file', 'mean')])
122122
])
123123
return wf
124-
125-
126-
def gen_chunks(dwi, mask, nvoxels=100):
127-
import nibabel as nb
128-
import numpy as np
129-
from math import sqrt, ceil
130-
131-
mask = nb.load(mask).get_data()
132-
mask[mask > 0] = 1
133-
mask[mask < 1] = 0
134-
135-
dshape = mask.shape
136-
mask = mask.reshape(-1).astype(np.uint8)
137-
138-
nzels = np.nonzero(mask)
139-
np.savez('nonzeroidx', nzels)
140-
els = np.sum(mask)
141-
142-
chunkside = int(ceil(sqrt(nvoxels)))
143-
chshape = (chunkside, chunkside, 1)
144-
chunkels = chunkside**2
145-
nchunks = int(ceil(els/chunkels))
146-
147-
data = nb.load(dwi).get_data().astype(np.float32)
148-
data = np.squeeze(data.reshape((mask.size, -1)).take(nzels, axis=0))
149-
150-
out_files = []
151-
out_masks = []
152-
153-
for i in xrange(nchunks):
154-
first = i * chunkels
155-
last = (i+1) * chunkels
156-
fill = 0
157-
158-
if last > els:
159-
fill = last - els
160-
last = els
161-
162-
datchunk = data[first:last, ...]
163-
mskchunk = np.ones((last-first), dtype=np.uint8)
164-
165-
if fill > 0:
166-
print 'filling %d elements' % fill
167-
mskchunk = np.hstack((mskchunk, np.zeros((fill,), dtype=np.uint8)))
168-
datchunk = np.vstack((datchunk, np.zeros((fill,
169-
data.shape[-1]), dtype=np.float32)))
170-
171-
mname = 'mask%05d.nii.gz' % i
172-
nb.Nifti1Image(mskchunk.reshape(chshape),
173-
None, None).to_filename(mname)
174-
out_masks.append(mname)
175-
176-
datashape = [s for s in chshape] + [-1]
177-
fname = 'chunk%05d.nii.gz' % i
178-
nb.Nifti1Image(datchunk.reshape(tuple(datashape)),
179-
None, None).to_filename(fname)
180-
out_files.append(fname)
181-
182-
return out_files, out_masks
183-

0 commit comments

Comments
 (0)