forked from jeffduda/GetYourBrainPipelined
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregistrationSimpleITK.py
More file actions
129 lines (90 loc) · 4.46 KB
/
registrationSimpleITK.py
File metadata and controls
129 lines (90 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
# Example registration code from:
# https://simpleitk.readthedocs.io/en/master/link_ImageRegistrationMethodDisplacement1_docs.html
import SimpleITK as sitk
import sys
import os
import logging
def command_iteration(method):
if (method.GetOptimizerIteration() == 0):
print(f"\tLevel: {method.GetCurrentLevel()}")
print(f"\tScales: {method.GetOptimizerScales()}")
print(f"#{method.GetOptimizerIteration()}")
print(f"\tMetric Value: {method.GetMetricValue():10.5f}")
print(f"\tLearningRate: {method.GetOptimizerLearningRate():10.5f}")
if (method.GetOptimizerConvergenceValue() != sys.float_info.max):
print(f"\tConvergence Value: {method.GetOptimizerConvergenceValue():.5e}")
def command_multiresolution_iteration(method):
print(f"\tStop Condition: {method.GetOptimizerStopConditionDescription()}")
print("============= Resolution Change =============")
#if len(sys.argv) < 4:
# print("Usage:", sys.argv[0], "<fixedImageFilter> <movingImageFile>", "<outputTransformFile>")
# sys.exit(1)
iDir = os.getenv("INPUT_DIR")
oDir = os.getenv("OUTPUT_DIR")
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
if len(sys.argv) > 3:
log.info("Starting SimpleITK registration")
movingFile = os.path.join(iDir, sys.argv[1])
fixedFile = os.path.join(iDir, sys.argv[2])
warpedFile = os.path.join(oDir, sys.argv[3])
log.info("Load inputs")
fixed = sitk.ReadImage(fixedFile, sitk.sitkFloat32)
moving = sitk.ReadImage(movingFile, sitk.sitkFloat32)
initialTx = sitk.CenteredTransformInitializer(fixed, moving, sitk.AffineTransform(fixed.GetDimension()))
R = sitk.ImageRegistrationMethod()
R.SetShrinkFactorsPerLevel([3, 2, 1])
R.SetSmoothingSigmasPerLevel([2, 1, 1])
R.SetMetricAsJointHistogramMutualInformation(20)
R.MetricUseFixedImageGradientFilterOff()
R.SetOptimizerAsGradientDescent(learningRate=1.0,numberOfIterations=100,estimateLearningRate=R.EachIteration)
R.SetOptimizerScalesFromPhysicalShift()
R.SetInitialTransform(initialTx)
R.SetInterpolator(sitk.sitkLinear)
#R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))
#R.AddCommand(sitk.sitkMultiResolutionIterationEvent,lambda: command_multiresolution_iteration(R))
log.info("Start registration")
outTx1 = R.Execute(fixed, moving)
#print("-------")
#print(outTx1)
#print(f"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}")
#print(f" Iteration: {R.GetOptimizerIteration()}")
#print(f" Metric value: {R.GetMetricValue()}")
displacementField = sitk.Image(fixed.GetSize(), sitk.sitkVectorFloat64)
displacementField.CopyInformation(fixed)
displacementTx = sitk.DisplacementFieldTransform(displacementField)
del displacementField
displacementTx.SetSmoothingGaussianOnUpdate(varianceForUpdateField=0.0,varianceForTotalField=1.5)
R.SetMovingInitialTransform(outTx1)
R.SetInitialTransform(displacementTx, inPlace=True)
R.SetMetricAsANTSNeighborhoodCorrelation(4)
R.MetricUseFixedImageGradientFilterOff()
R.SetShrinkFactorsPerLevel([3, 2, 1])
R.SetSmoothingSigmasPerLevel([2, 1, 1])
R.SetOptimizerScalesFromPhysicalShift()
R.SetOptimizerAsGradientDescent(learningRate=1,numberOfIterations=300,estimateLearningRate=R.EachIteration)
R.Execute(fixed, moving)
#print("-------")
#print(displacementTx)
#print(f"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}")
#print(f" Iteration: {R.GetOptimizerIteration()}")
#print(f" Metric value: {R.GetMetricValue()}")
compositeTx = sitk.CompositeTransform([outTx1, displacementTx])
#sitk.WriteTransform(compositeTx, sys.argv[3])
if ("SITK_NOSHOW" not in os.environ):
#sitk.Show(displacementTx.GetDisplacementField(), "Displacement Field")
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed)
resampler.SetInterpolator(sitk.sitkLinear)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(compositeTx)
out = resampler.Execute(moving)
writer = sitk.ImageFileWriter()
writer.SetFileName(warpedFile)
writer.Execute(out)
simg1 = sitk.Cast(sitk.RescaleIntensity(fixed), sitk.sitkUInt8)
simg2 = sitk.Cast(sitk.RescaleIntensity(out), sitk.sitkUInt8)
cimg = sitk.Compose(simg1, simg2, simg1 // 2. + simg2 // 2.)
#sitk.Show(cimg, "ImageRegistration1 Composition")