Skip to content

Commit effa285

Browse files
paralellizing gibbs
1 parent 202e435 commit effa285

File tree

3 files changed

+35
-19
lines changed

3 files changed

+35
-19
lines changed

micaflow/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,15 @@ def format_help(self):
657657
default=3,
658658
help="Dimension of the DWI image referring to shells (default: 3)",
659659
)
660+
bias_corr_parser.add_argument(
661+
"--threads",
662+
type=int,
663+
default=1,
664+
help="Number of threads to use for bias correction (default: 1)",
665+
)
666+
bias_corr_parser.add_argument(
667+
"--gibbs", action="store_true", help="Apply Gibbs ringing correction"
668+
)
660669

661670
# DICE Calculator command
662671
dice_parser = subparsers.add_parser(

micaflow/resources/Snakefile

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,9 @@ rule bias_field_correction:
269269
mask = rules.skull_strip_t1w.output.mask
270270
output:
271271
corrected = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_T1w-space_T1w.nii.gz"
272-
threads: LIGHT_THREADS
272+
threads: HEAVY_THREADS
273273
shell:
274-
"micaflow bias_correction -i {input.image} -o {output.corrected} -m {input.mask} --gibbs"
274+
"micaflow bias_correction -i {input.image} -o {output.corrected} -m {input.mask} --gibbs --threads {threads}"
275275

276276
# Place these rules in a conditional block to only run when FLAIR is available
277277
if RUN_FLAIR:
@@ -302,9 +302,9 @@ if RUN_FLAIR:
302302
mask = rules.skull_strip_flair.output.mask
303303
output:
304304
corrected = f"{OUT_DIR}/{SUBJECT}/{SESSION}/anat/{SUBJECT}_{SESSION}_FLAIR-space_FLAIR.nii.gz"
305-
threads: LIGHT_THREADS
305+
threads: HEAVY_THREADS
306306
shell:
307-
"micaflow bias_correction -i {input.image} -o {output.corrected} -m {input.mask} --gibbs"
307+
"micaflow bias_correction -i {input.image} -o {output.corrected} -m {input.mask} --gibbs --threads {threads}"
308308

309309
rule registration_t1w:
310310
input:
@@ -620,7 +620,7 @@ if RUN_DWI:
620620
output:
621621
corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_bias-corrected_DWI.nii.gz",
622622
b0_corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_biascorrected-b0.nii.gz"
623-
threads: LIGHT_THREADS
623+
threads: HEAVY_THREADS
624624
shell:
625625
"""
626626
micaflow bias_correction \
@@ -629,7 +629,8 @@ if RUN_DWI:
629629
--b0-output {output.b0_corrected} \
630630
--output {output.corrected} \
631631
--shell-dimension {SHELL_DIMENSION} \
632-
--gibbs
632+
--gibbs \
633+
--threads {threads}
633634
"""
634635
rule b0_synthseg:
635636
input:
@@ -958,7 +959,7 @@ if RUN_DWI:
958959
output:
959960
corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_denoised_bias-corrected_DWI.nii.gz",
960961
b0_corrected = f"{TEMP_DIR}/{SUBJECT}_{SESSION}_biascorrected-b0.nii.gz"
961-
threads: LIGHT_THREADS
962+
threads: HEAVY_THREADS
962963
shell:
963964
"""
964965
micaflow bias_correction \
@@ -968,7 +969,8 @@ if RUN_DWI:
968969
--output {output.corrected} \
969970
--mask {input.mask} \
970971
--shell-dimension {SHELL_DIMENSION} \
971-
--gibbs
972+
--gibbs \
973+
--threads {threads}
972974
"""
973975

974976
rule dwi_registration:

micaflow/scripts/bias_correction.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def print_help_message():
537537
print(help_text)
538538

539539

540-
def bias_field_correction_3d(image_path, output_path, mask_path=None, gibbs=False):
540+
def bias_field_correction_3d(image_path, output_path, mask_path=None, gibbs=False, threads=1):
541541
"""
542542
Perform N4 bias field correction on a 3D anatomical image.
543543
@@ -602,11 +602,9 @@ def bias_field_correction_3d(image_path, output_path, mask_path=None, gibbs=Fals
602602
img = ants.image_read(image_path)
603603

604604
if gibbs:
605-
if not HAS_DIPY:
606-
raise ImportError("DIPY is required for Gibbs removal. Please install dipy.")
607605
print(f"{CYAN}Running Gibbs ringing removal...{RESET}")
608606
arr = img.numpy()
609-
gibbs_removal(arr, slice_axis=2, n_points=3, inplace=True, num_processes=1)
607+
gibbs_removal(arr, slice_axis=2, n_points=3, inplace=True, num_processes=threads)
610608
img = img.new_image_like(arr)
611609

612610
print(f" Image shape: {img.shape}")
@@ -630,7 +628,7 @@ def bias_field_correction_3d(image_path, output_path, mask_path=None, gibbs=Fals
630628

631629

632630
def bias_field_correction_4d(image_path, mask_path=None, output_path=None,
633-
b0_path=None, b0_corrected_path=None, shell_dimension=3, gibbs=False):
631+
b0_path=None, b0_corrected_path=None, shell_dimension=3, gibbs=False, threads=1):
634632
"""
635633
Apply N4 bias field correction to a 4D diffusion image.
636634
@@ -726,7 +724,7 @@ def bias_field_correction_4d(image_path, mask_path=None, output_path=None,
726724
if gibbs:
727725
print(f"{CYAN}Running Gibbs ringing removal on 4D data...{RESET}")
728726
arr = img.numpy()
729-
gibbs_removal(arr, slice_axis=2, n_points=3, inplace=True, num_processes=1)
727+
gibbs_removal(arr, slice_axis=2, n_points=3, inplace=True, num_processes=threads)
730728
img = img.new_image_like(arr)
731729

732730
img_data = img.numpy()
@@ -907,7 +905,7 @@ def needs_resampling(img1, img2):
907905

908906

909907
def run_bias_field_correction(image_path, output_path, mask_path=None, mode="auto",
910-
b0_path=None, b0_corrected_path=None, shell_dimension=3, gibbs=False):
908+
b0_path=None, b0_corrected_path=None, shell_dimension=3, gibbs=False, threads=1):
911909
"""
912910
Run bias field correction with automatic dimensionality detection.
913911
@@ -995,6 +993,8 @@ def run_bias_field_correction(image_path, output_path, mask_path=None, mode="aut
995993
bias_field_correction_3d : 3D-specific implementation
996994
bias_field_correction_4d : 4D-specific implementation
997995
"""
996+
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = str(threads)
997+
os.environ["OMP_NUM_THREADS"] = str(threads) # OpenMP threads for ANTs
998998
# If auto mode, determine if image is 3D or 4D
999999
print(f"{CYAN}Detecting image dimensionality...{RESET}")
10001000
img = ants.image_read(image_path)
@@ -1049,10 +1049,10 @@ def run_bias_field_correction(image_path, output_path, mask_path=None, mode="aut
10491049
if mode == "4d":
10501050
return bias_field_correction_4d(
10511051
image_path, mask_path, output_path,
1052-
b0_path, b0_corrected_path, shell_dimension, gibbs
1052+
b0_path, b0_corrected_path, shell_dimension, gibbs, threads
10531053
)
10541054
else: # 3d
1055-
return bias_field_correction_3d(image_path, output_path, mask_path, gibbs)
1055+
return bias_field_correction_3d(image_path, output_path, mask_path, gibbs, threads)
10561056
finally:
10571057
# Clean up temporary files
10581058
if temp_mask_path and os.path.exists(temp_mask_path):
@@ -1102,7 +1102,11 @@ def run_bias_field_correction(image_path, output_path, mask_path=None, mode="aut
11021102
"--gibbs", action="store_true",
11031103
help="Apply Gibbs ringing removal (requires DIPY)."
11041104
)
1105-
1105+
parser.add_argument(
1106+
"--threads", type=int, default=1,
1107+
help="Number of threads to use for processing (default: 1)."
1108+
)
1109+
11061110
args = parser.parse_args()
11071111

11081112
try:
@@ -1136,7 +1140,8 @@ def run_bias_field_correction(image_path, output_path, mask_path=None, mode="aut
11361140
args.b0,
11371141
args.b0_output,
11381142
args.shell_dimension,
1139-
args.gibbs
1143+
args.gibbs,
1144+
args.threads
11401145
)
11411146

11421147
print(f"\n{GREEN}{BOLD}Bias field correction completed successfully!{RESET}")

0 commit comments

Comments
 (0)