Skip to content

Commit 1baedcc

Browse files
authored
Merge pull request #119 from sarthakpati/greedy
Added `greedy` and `elastix` registrators
2 parents c417b8c + 02d46ff commit 1baedcc

File tree

9 files changed

+258
-1
lines changed

9 files changed

+258
-1
lines changed

brainles_preprocessing/registration/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@
1717

1818

1919
from .niftyreg.niftyreg import NiftyRegRegistrator
20+
21+
from .elastix.elastix import ElastixRegistrator
22+
23+
from .greedy.greedy import GreedyRegistrator

brainles_preprocessing/registration/elastix/__init__.py

Whitespace-only changes.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# TODO add typing and docs
2+
from typing import Optional
3+
import os
4+
5+
import itk
6+
7+
from brainles_preprocessing.registration.registrator import Registrator
8+
from brainles_preprocessing.utils import check_and_add_suffix
9+
10+
11+
class ElastixRegistrator(Registrator):
12+
def __init__(
13+
self,
14+
):
15+
pass
16+
17+
def register(
18+
self,
19+
fixed_image_path: str,
20+
moving_image_path: str,
21+
transformed_image_path: str,
22+
matrix_path: str,
23+
log_file_path: Optional[str] = None,
24+
parameter_object: Optional[itk.ParameterObject] = None,
25+
) -> None:
26+
"""
27+
Register images using elastix.
28+
29+
Args:
30+
fixed_image_path (str): Path to the fixed image.
31+
moving_image_path (str): Path to the moving image.
32+
transformed_image_path (str): Path to the transformed image (output).
33+
matrix_path (str): Path to the transformation matrix (output). This gets overwritten if it already exists.
34+
log_file_path (Optional[str]): Path to the log file.
35+
parameter_object (Optional[itk.ParameterObject]): The parameter object for elastix registration.
36+
"""
37+
# initialize parameter object
38+
if parameter_object is None:
39+
parameter_object = self.__initialize_parameter_object()
40+
# add .txt suffix to the matrix path if it doesn't have any extension
41+
matrix_path = check_and_add_suffix(matrix_path, ".txt")
42+
43+
# read images as itk images
44+
fixed_image = itk.imread(fixed_image_path)
45+
moving_image = itk.imread(moving_image_path)
46+
47+
if log_file_path is not None:
48+
# split log_file_path
49+
log_path, log_file = os.path.split(log_file_path)
50+
result_image, result_transform_params = itk.elastix_registration_method(
51+
fixed_image,
52+
moving_image,
53+
parameter_object=parameter_object,
54+
log_to_file=True,
55+
log_file_name=log_file,
56+
output_directory=log_path,
57+
)
58+
else:
59+
result_image, result_transform_params = itk.elastix_registration_method(
60+
fixed_image,
61+
moving_image,
62+
parameter_object=parameter_object,
63+
log_to_console=True,
64+
)
65+
66+
itk.imwrite(result_image, transformed_image_path)
67+
68+
if not os.path.exists(matrix_path):
69+
result_transform_params.WriteParameterFile(
70+
result_transform_params.GetParameterMap(0),
71+
matrix_path,
72+
)
73+
74+
def transform(
75+
self,
76+
fixed_image_path: str,
77+
moving_image_path: str,
78+
transformed_image_path: str,
79+
matrix_path: str,
80+
log_file_path: Optional[str] = None,
81+
) -> None:
82+
"""
83+
Apply a transformation using elastix.
84+
85+
Args:
86+
fixed_image_path (str): Path to the fixed image.
87+
moving_image_path (str): Path to the moving image.
88+
transformed_image_path (str): Path to the transformed image (output).
89+
matrix_path (str): Path to the transformation matrix (output). This gets overwritten if it already exists.
90+
log_file_path (Optional[str]): Path to the log file.
91+
"""
92+
parameter_object = self.__initialize_parameter_object()
93+
94+
# check if the matrix file exists
95+
if os.path.exists(matrix_path):
96+
parameter_object.SetParameter(
97+
0, "InitialTransformParametersFileName", matrix_path
98+
)
99+
100+
self.register(
101+
fixed_image_path,
102+
moving_image_path,
103+
transformed_image_path,
104+
matrix_path,
105+
log_file_path,
106+
parameter_object,
107+
)
108+
109+
def __initialize_parameter_object(self) -> itk.ParameterObject:
110+
"""
111+
Initialize the parameter object for elastix registration.
112+
113+
Returns:
114+
itk.ParameterObject: The parameter object for registration.
115+
"""
116+
parameter_object = itk.ParameterObject.New()
117+
default_rigid_parameter_map = parameter_object.GetDefaultParameterMap("rigid")
118+
parameter_object.AddParameterMap(default_rigid_parameter_map)
119+
return parameter_object

brainles_preprocessing/registration/greedy/__init__.py

Whitespace-only changes.
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# TODO add typing and docs
2+
from typing import Optional
3+
import contextlib
4+
import os
5+
6+
from picsl_greedy import Greedy3D
7+
8+
from brainles_preprocessing.registration.registrator import Registrator
9+
from brainles_preprocessing.utils import check_and_add_suffix
10+
11+
12+
class GreedyRegistrator(Registrator):
13+
def __init__(
14+
self,
15+
):
16+
pass
17+
18+
def register(
19+
self,
20+
fixed_image_path: str,
21+
moving_image_path: str,
22+
transformed_image_path: str,
23+
matrix_path: str,
24+
log_file_path: Optional[str] = None,
25+
) -> None:
26+
"""
27+
Register images using greedy. Ref: https://pypi.org/project/picsl-greedy/ and https://greedy.readthedocs.io/en/latest/reference.html#greedy-usage
28+
29+
Args:
30+
fixed_image_path (str): Path to the fixed image.
31+
moving_image_path (str): Path to the moving image.
32+
transformed_image_path (str): Path to the transformed image (output).
33+
matrix_path (str): Path to the transformation matrix (output). This gets overwritten if it already exists.
34+
log_file_path (Optional[str]): Path to the log file, which is not used.
35+
"""
36+
# add .txt suffix to the matrix path if it doesn't have any extension
37+
matrix_path = check_and_add_suffix(matrix_path, ".mat")
38+
39+
registor = Greedy3D()
40+
# these parameters are taken from the OG BraTS Pipeline [https://github.com/CBICA/CaPTk/blob/master/src/applications/BraTSPipeline.cxx]
41+
command_to_run = f"-i {fixed_image_path} {moving_image_path} -o {matrix_path} -a -dof 6 -m NMI -n 100x50x5 -ia-image-centers"
42+
43+
if log_file_path is not None:
44+
with open(log_file_path, "a+") as f:
45+
with contextlib.redirect_stdout(f):
46+
registor.execute(command_to_run)
47+
else:
48+
registor.execute(command_to_run)
49+
50+
self.transform(
51+
fixed_image_path, moving_image_path, transformed_image_path, matrix_path
52+
)
53+
54+
def transform(
55+
self,
56+
fixed_image_path: str,
57+
moving_image_path: str,
58+
transformed_image_path: str,
59+
matrix_path: str,
60+
interpolator: Optional[str] = "LINEAR",
61+
log_file_path: Optional[str] = None,
62+
) -> None:
63+
"""
64+
Apply a transformation using greedy.
65+
66+
Args:
67+
fixed_image_path (str): Path to the fixed image.
68+
moving_image_path (str): Path to the moving image.
69+
transformed_image_path (str): Path to the transformed image (output).
70+
matrix_path (str): Path to the transformation matrix (output). This gets overwritten if it already exists.
71+
interpolator (Optional[str]): The interpolator to use; one of NN, LINEAR or LABEL.
72+
log_file_path (Optional[str]): Path to the log file, which is not used.
73+
"""
74+
registor = Greedy3D()
75+
interpolator_upper = interpolator.upper()
76+
if "LABEL" in interpolator_upper:
77+
interpolator_upper += " 0.3vox"
78+
79+
matrix_path = check_and_add_suffix(matrix_path, ".mat")
80+
81+
if not os.path.exists(matrix_path):
82+
self.register(
83+
fixed_image_path,
84+
moving_image_path,
85+
transformed_image_path,
86+
matrix_path,
87+
log_file_path,
88+
)
89+
90+
command_to_run = f"-rf {fixed_image_path} -rm {moving_image_path} {transformed_image_path} -r {matrix_path} -ri {interpolator_upper}"
91+
if log_file_path is not None:
92+
with open(log_file_path, "a+") as f:
93+
with contextlib.redirect_stdout(f):
94+
registor.execute(command_to_run)
95+
else:
96+
registor.execute(command_to_run)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .generic import check_and_add_suffix
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
def check_and_add_suffix(filename: str, suffix: str) -> str:
2+
"""
3+
Adds a suffix to the filename if it doesn't already have it.
4+
5+
Parameters:
6+
filename (str): The filename to check and potentially modify.
7+
suffix (str): The suffix to add to the filename.
8+
9+
Returns:
10+
str: The filename with the suffix added if needed.
11+
"""
12+
filename_copy = filename
13+
if not filename_copy.endswith(suffix):
14+
filename_copy += suffix
15+
return filename_copy

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,15 @@ rich = "^13.6.0"
7070

7171
# optional registration backends
7272
ereg = { version = "^0.0.10", optional = true }
73+
itk-elastix = { version = "^0.20.0", optional = true }
74+
picsl_greedy = { version = "^0.0.6", optional = true }
7375

7476

7577
[tool.poetry.extras]
76-
all = ["ereg"]
78+
all = ["ereg", "itk-elastix", "picsl_greedy"]
7779
ereg = ["ereg"]
80+
itk-elastix = ["itk-elastix"]
81+
picsl_greedy = ["picsl_greedy"]
7882

7983

8084
[tool.poetry.group.dev.dependencies]

tests/test_registrators.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from brainles_preprocessing.registration.ANTs.ANTs import ANTsRegistrator
44
from brainles_preprocessing.registration.eReg.eReg import eRegRegistrator
55
from brainles_preprocessing.registration.niftyreg.niftyreg import NiftyRegRegistrator
6+
from brainles_preprocessing.registration.elastix.elastix import ElastixRegistrator
7+
from brainles_preprocessing.registration.greedy.greedy import GreedyRegistrator
68

79
import unittest
810

@@ -29,3 +31,19 @@ def get_registrator(self):
2931

3032
def get_method_and_extension(self):
3133
return "ereg", "mat"
34+
35+
36+
class TestElastixRegistrator(RegistratorBase, unittest.TestCase):
37+
def get_registrator(self):
38+
return ElastixRegistrator()
39+
40+
def get_method_and_extension(self):
41+
return "elastix", "txt"
42+
43+
44+
class TestGreedyRegistrator(RegistratorBase, unittest.TestCase):
45+
def get_registrator(self):
46+
return GreedyRegistrator()
47+
48+
def get_method_and_extension(self):
49+
return "greedy", "mat"

0 commit comments

Comments
 (0)