Skip to content

Commit f6d870a

Browse files
authored
Merge pull request #417 from dPys/development
Development
2 parents 9128971 + fa850f8 commit f6d870a

20 files changed

+383
-194
lines changed

docs/usage.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ Command-Line Arguments
119119
**********************
120120

121121
.. argparse::
122-
:module: pynets.cli.pynets_bids
122+
:module: pynets.cli.pynets_run
123123
:func: get_bids_parser
124124
:prog: pynets
125125

126126
.. argparse::
127-
:module: pynets.cli.pynets_run
127+
:module: pynets.cli.pynets_bids
128128
:func: get_parser
129129
:prog: pynets
130130

pynets/cli/pynets_bids.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
PyNets BIDS cli
33
"""
44
import bids
5+
56
from pynets.core.utils import as_list, merge_dicts
67

78

@@ -440,7 +441,6 @@ def main():
440441
import yaml
441442
import itertools
442443
from types import SimpleNamespace
443-
from pathlib import Path
444444
import pkg_resources
445445
from pynets.core.utils import flatten
446446
from pynets.cli.pynets_run import build_workflow

pynets/cli/pynets_run.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,18 @@ def get_parser():
400400
"is 20. If you wish to iterate the pipeline across multiple "
401401
"minimums, separate the list by space (e.g. 10 30 50).\n",
402402
)
403+
parser.add_argument(
404+
"-em",
405+
metavar="Error margin",
406+
default=12,
407+
nargs="+",
408+
help="(Hyperparameter): Distance (in the units of the streamlines, "
409+
"usually mm). If any coordinate in the streamline is within this "
410+
"distance from the center of any voxel in the ROI, the filtering "
411+
"criterion is set to True for this streamline, otherwise False. "
412+
"Defaults to the distance between the center of each voxel and "
413+
"the corner of the voxel.\n",
414+
)
403415
parser.add_argument(
404416
"-dg",
405417
metavar="Direction getter",
@@ -973,6 +985,20 @@ def build_workflow(args, retval):
973985
min_length_list = None
974986
else:
975987
min_length_list = None
988+
error_margin = args.em
989+
if error_margin:
990+
if (isinstance(error_margin, list)) and (len(error_margin) > 1):
991+
error_margin_list = error_margin
992+
error_margin = None
993+
elif error_margin == ["None"]:
994+
error_margin_list = None
995+
elif isinstance(error_margin, list):
996+
error_margin = error_margin[0]
997+
error_margin_list = None
998+
else:
999+
error_margin_list = None
1000+
else:
1001+
error_margin_list = None
9761002
directget = args.dg
9771003
if directget:
9781004
if (isinstance(directget, list)) and (len(directget) > 1):
@@ -1008,19 +1034,21 @@ def build_workflow(args, retval):
10081034
) as stream:
10091035
try:
10101036
hardcoded_params = yaml.load(stream)
1011-
maxcrossing = hardcoded_params["maxcrossing"][0]
1037+
maxcrossing = hardcoded_params['tracking']["maxcrossing"][0]
10121038
local_corr = hardcoded_params["clustering_local_conn"][0]
1013-
track_type = hardcoded_params["tracking_method"][0]
1014-
tiss_class = hardcoded_params["tissue_classifier"][0]
1015-
target_samples = hardcoded_params["tracking_samples"][0]
1039+
track_type = hardcoded_params['tracking']["tracking_method"][0]
1040+
tiss_class = hardcoded_params['tracking']["tissue_classifier"][0]
1041+
target_samples = hardcoded_params['tracking']["tracking_samples"][0]
10161042
use_parcel_naming = hardcoded_params["parcel_naming"][0]
1017-
step_list = hardcoded_params["step_list"]
1018-
curv_thr_list = hardcoded_params["curv_thr_list"]
1043+
step_list = hardcoded_params['tracking']["step_list"]
1044+
curv_thr_list = hardcoded_params['tracking']["curv_thr_list"]
10191045
nilearn_parc_atlases = hardcoded_params["nilearn_parc_atlases"]
10201046
nilearn_coord_atlases = hardcoded_params["nilearn_coord_atlases"]
10211047
nilearn_prob_atlases = hardcoded_params["nilearn_prob_atlases"]
10221048
local_atlases = hardcoded_params["local_atlases"]
10231049
template_name = hardcoded_params['template'][0]
1050+
roi_neighborhood_tol = \
1051+
hardcoded_params['tracking']["roi_neighborhood_tol"][0]
10241052

10251053
if track_type == "particle":
10261054
tiss_class = "cmc"
@@ -1800,6 +1828,24 @@ def build_workflow(args, retval):
18001828
else:
18011829
print(f"{Fore.GREEN}Iterating minimum streamline lengths:")
18021830
print(f"{Fore.BLUE}{', '.join(min_length_list)}")
1831+
if error_margin:
1832+
if float(roi_neighborhood_tol) <= float(error_margin):
1833+
raise ValueError(
1834+
'roi_neighborhood_tol preset cannot be less than '
1835+
'the value of the structural connectome error_margin'
1836+
' parameter.')
1837+
print(f"{Fore.GREEN}Using {Fore.BLUE}{error_margin}"
1838+
f"mm{Fore.GREEN} error margin...")
1839+
else:
1840+
for em in error_margin_list:
1841+
if float(roi_neighborhood_tol) <= float(em):
1842+
raise ValueError(
1843+
'roi_neighborhood_tol preset cannot be less than '
1844+
'the value of the structural connectome error_margin'
1845+
' parameter.')
1846+
print(f"{Fore.GREEN}Iterating minimum streamline lengths:")
1847+
print(f"{Fore.BLUE}{', '.join(error_margin_list)}")
1848+
18031849
if target_samples:
18041850
print(f"{Fore.GREEN}Using {Fore.BLUE}{target_samples} "
18051851
f"{Fore.GREEN}streamline samples...")
@@ -2217,6 +2263,7 @@ def init_wf_single_subject(
22172263
track_type,
22182264
min_length,
22192265
maxcrossing,
2266+
error_margin,
22202267
directget,
22212268
tiss_class,
22222269
runtime_dict,
@@ -2231,6 +2278,7 @@ def init_wf_single_subject(
22312278
waymask,
22322279
local_corr,
22332280
min_length_list,
2281+
error_margin_list,
22342282
extract_strategy,
22352283
extract_strategy_list,
22362284
outdir,
@@ -2384,6 +2432,7 @@ def init_wf_single_subject(
23842432
track_type,
23852433
min_length,
23862434
maxcrossing,
2435+
error_margin,
23872436
directget,
23882437
tiss_class,
23892438
runtime_dict,
@@ -2398,6 +2447,7 @@ def init_wf_single_subject(
23982447
waymask,
23992448
local_corr,
24002449
min_length_list,
2450+
error_margin_list,
24012451
extract_strategy,
24022452
extract_strategy_list,
24032453
outdir,
@@ -2650,6 +2700,7 @@ def wf_multi_subject(
26502700
track_type,
26512701
min_length,
26522702
maxcrossing,
2703+
error_margin,
26532704
directget,
26542705
tiss_class,
26552706
runtime_dict,
@@ -2664,6 +2715,7 @@ def wf_multi_subject(
26642715
waymask,
26652716
local_corr,
26662717
min_length_list,
2718+
error_margin_list,
26672719
extract_strategy,
26682720
extract_strategy_list,
26692721
outdir,
@@ -2776,6 +2828,7 @@ def wf_multi_subject(
27762828
track_type=track_type,
27772829
min_length=min_length,
27782830
maxcrossing=maxcrossing,
2831+
error_margin=error_margin,
27792832
directget=directget,
27802833
tiss_class=tiss_class,
27812834
runtime_dict=runtime_dict,
@@ -2790,6 +2843,7 @@ def wf_multi_subject(
27902843
waymask=waymask,
27912844
local_corr=local_corr,
27922845
min_length_list=min_length_list,
2846+
error_margin_list=error_margin_list,
27932847
extract_strategy=extract_strategy,
27942848
extract_strategy_list=extract_strategy_list,
27952849
outdir=subj_dir,
@@ -2893,6 +2947,7 @@ def wf_multi_subject(
28932947
track_type=track_type,
28942948
min_length=min_length,
28952949
maxcrossing=maxcrossing,
2950+
error_margin=error_margin,
28962951
directget=directget,
28972952
tiss_class=tiss_class,
28982953
runtime_dict=runtime_dict,
@@ -2907,6 +2962,7 @@ def wf_multi_subject(
29072962
waymask=waymask,
29082963
local_corr=local_corr,
29092964
min_length_list=min_length_list,
2965+
error_margin_list=error_margin_list,
29102966
extract_strategy=extract_strategy,
29112967
extract_strategy_list=extract_strategy_list,
29122968
outdir=subj_dir,
@@ -3008,6 +3064,7 @@ def wf_multi_subject(
30083064
track_type,
30093065
min_length,
30103066
maxcrossing,
3067+
error_margin,
30113068
directget,
30123069
tiss_class,
30133070
runtime_dict,
@@ -3022,6 +3079,7 @@ def wf_multi_subject(
30223079
waymask,
30233080
local_corr,
30243081
min_length_list,
3082+
error_margin_list,
30253083
extract_strategy,
30263084
extract_strategy_list,
30273085
outdir,
@@ -3195,6 +3253,7 @@ def wf_multi_subject(
31953253
track_type,
31963254
min_length,
31973255
maxcrossing,
3256+
error_margin,
31983257
directget,
31993258
tiss_class,
32003259
runtime_dict,
@@ -3209,6 +3268,7 @@ def wf_multi_subject(
32093268
waymask,
32103269
local_corr,
32113270
min_length_list,
3271+
error_margin_list,
32123272
extract_strategy,
32133273
extract_strategy_list,
32143274
subj_dir,

pynets/core/interfaces.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,7 @@ class _PlotStructInputSpec(BaseInterfaceInputSpec):
952952
track_type = traits.Any(mandatory=True)
953953
directget = traits.Any(mandatory=True)
954954
min_length = traits.Any(mandatory=True)
955-
955+
error_margin = traits.Any(mandatory=True)
956956

957957
class _PlotStructOutputSpec(BaseInterfaceInputSpec):
958958
"""Output interface wrapper for PlotStruct"""
@@ -998,6 +998,7 @@ def _run_interface(self, runtime):
998998
self.inputs.track_type,
999999
self.inputs.directget,
10001000
self.inputs.min_length,
1001+
self.inputs.error_margin
10011002
)
10021003

10031004
self._results["out"] = "None"
@@ -2525,6 +2526,7 @@ class _TrackingInputSpec(BaseInterfaceInputSpec):
25252526
step_list = traits.List(mandatory=True)
25262527
track_type = traits.Str(mandatory=True)
25272528
min_length = traits.Any(mandatory=True)
2529+
error_margin = traits.Any(mandatory=True)
25282530
maxcrossing = traits.Any(mandatory=True)
25292531
directget = traits.Str(mandatory=True)
25302532
conn_model = traits.Str(mandatory=True)
@@ -2549,8 +2551,6 @@ class _TrackingInputSpec(BaseInterfaceInputSpec):
25492551
fa_path = File(exists=True, mandatory=True)
25502552
waymask = traits.Any(mandatory=False)
25512553
t1w2dwi = File(exists=True, mandatory=True)
2552-
roi_neighborhood_tol = traits.Any(10, mandatory=True, usedefault=True)
2553-
sphere = traits.Str('repulsion724', mandatory=True, usedefault=True)
25542554

25552555

25562556
class _TrackingOutputSpec(TraitedSpec):
@@ -2583,9 +2583,8 @@ class _TrackingOutputSpec(TraitedSpec):
25832583
dm_path = File(exists=True, mandatory=True)
25842584
directget = traits.Str(mandatory=True)
25852585
labels_im_file = File(exists=True, mandatory=True)
2586-
roi_neighborhood_tol = traits.Any()
25872586
min_length = traits.Any()
2588-
2587+
error_margin = traits.Any()
25892588

25902589
class Tracking(SimpleInterface):
25912590
"""Interface wrapper for Tracking"""
@@ -2617,7 +2616,10 @@ def _run_interface(self, runtime):
26172616
pkg_resources.resource_filename("pynets", "runconfig.yaml"), "r"
26182617
) as stream:
26192618
hardcoded_params = yaml.load(stream)
2620-
use_life = hardcoded_params["use_life"][0]
2619+
use_life = hardcoded_params['tracking']["use_life"][0]
2620+
roi_neighborhood_tol = hardcoded_params['tracking']["roi_neighborhood_tol"][0]
2621+
sphere = hardcoded_params['tracking']["sphere"][0]
2622+
26212623
stream.close()
26222624

26232625
dir_path = utils.do_dir_path(
@@ -2682,12 +2684,12 @@ def _run_interface(self, runtime):
26822684
"_",
26832685
"%s"
26842686
% (
2685-
"%s%s" % (self.inputs.node_size, "mm_")
2687+
"%s%s" % (self.inputs.node_size, "mm")
26862688
if (
26872689
(self.inputs.node_size != "parc")
26882690
and (self.inputs.node_size is not None)
26892691
)
2690-
else "parc_"
2692+
else "parc"
26912693
),
26922694
".npy",
26932695
)
@@ -2829,16 +2831,17 @@ def _run_interface(self, runtime):
28292831

28302832
# Iteratively build a list of streamlines for each ROI while tracking
28312833
print(
2832-
f"{Fore.GREEN}Target number of samples: {Fore.BLUE} "
2834+
f"{Fore.GREEN}Target number of cumulative streamlines: "
2835+
f"{Fore.BLUE} "
28332836
f"{self.inputs.target_samples}"
28342837
)
28352838
print(Style.RESET_ALL)
28362839
print(
2837-
f"{Fore.GREEN}Using curvature threshold(s): {Fore.BLUE} "
2840+
f"{Fore.GREEN}Curvature threshold(s): {Fore.BLUE} "
28382841
f"{self.inputs.curv_thr_list}"
28392842
)
28402843
print(Style.RESET_ALL)
2841-
print(f"{Fore.GREEN}Using step size(s): {Fore.BLUE} "
2844+
print(f"{Fore.GREEN}Step size(s): {Fore.BLUE} "
28422845
f"{self.inputs.step_list}")
28432846
print(Style.RESET_ALL)
28442847
print(f"{Fore.GREEN}Tracking type: {Fore.BLUE} "
@@ -2865,13 +2868,13 @@ def _run_interface(self, runtime):
28652868
atlas_data_wm_gm_int,
28662869
parcels,
28672870
model,
2868-
get_sphere(self.inputs.sphere),
2871+
get_sphere(sphere),
28692872
self.inputs.directget,
28702873
self.inputs.curv_thr_list,
28712874
self.inputs.step_list,
28722875
self.inputs.track_type,
28732876
self.inputs.maxcrossing,
2874-
int(self.inputs.roi_neighborhood_tol),
2877+
int(roi_neighborhood_tol),
28752878
self.inputs.min_length,
28762879
waymask_data,
28772880
B0_mask_data,
@@ -2880,7 +2883,7 @@ def _run_interface(self, runtime):
28802883
self.inputs.tiss_class, B0_mask_tmp_path
28812884
)
28822885

2883-
del model, parcels, atlas_data_wm_gm_int, waymask_data
2886+
del model, parcels, atlas_data_wm_gm_int
28842887
gc.collect()
28852888

28862889
# Save streamlines to trk
@@ -2927,10 +2930,16 @@ def _run_interface(self, runtime):
29272930
dwi_img = nib.load(dwi_file_tmp_path, mmap=True)
29282931
dwi_data = dwi_img.get_fdata().astype('float32')
29292932
orig_count = len(streamlines)
2933+
2934+
if self.inputs.waymask:
2935+
mask_data = waymask_data
2936+
else:
2937+
mask_data = nib.load(wm_in_dwi_tmp_path
2938+
).get_fdata().astype('float32')
29302939
try:
29312940
streamlines = evaluate_streamline_plausibility(
2932-
dwi_data, gtab, B0_mask_data, streamlines,
2933-
sphere=self.inputs.sphere)
2941+
dwi_data, gtab, mask_data, streamlines,
2942+
sphere=sphere)
29342943
except BaseException:
29352944
print(f"Linear Fascicle Evaluation failed. Visually checking "
29362945
f"streamlines output {namer_dir}/{op.basename(streams)}"
@@ -2940,7 +2949,7 @@ def _run_interface(self, runtime):
29402949
'the tractogram!')
29412950
del dwi_data
29422951

2943-
del B0_mask_data
2952+
del B0_mask_data, waymask_data
29442953

29452954
stf = StatefulTractogram(
29462955
streamlines,
@@ -3011,9 +3020,8 @@ def _run_interface(self, runtime):
30113020
self._results["dm_path"] = dm_path
30123021
self._results["directget"] = self.inputs.directget
30133022
self._results["labels_im_file"] = labels_im_file_tmp_path
3014-
self._results["roi_neighborhood_tol"] = \
3015-
self.inputs.roi_neighborhood_tol
30163023
self._results["min_length"] = self.inputs.min_length
3024+
self._results["error_margin"] = self.inputs.error_margin
30173025

30183026
tmp_files = [B0_mask_tmp_path, gtab_file_tmp_path,
30193027
labels_im_file_tmp_path_wm_gm_int, dwi_file_tmp_path]

0 commit comments

Comments
 (0)