Skip to content

Commit 04bb21f

Browse files
authored
Merge pull request #117 from mrava87/dev
Added new tutorial on MRI imaging
2 parents 7272e2e + 3416724 commit 04bb21f

File tree

3 files changed

+143
-0
lines changed

3 files changed

+143
-0
lines changed

testdata/brainphantom.mat

122 KB
Binary file not shown.

testdata/spiralsampling.mat

4.19 KB
Binary file not shown.

tutorials/brainmri.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
r"""
2+
MRI Imaging and Segmentation of Brain
3+
=====================================
4+
This tutorial considers the well-known problem of MRI imaging, where given the availability of a sparsely sampled
5+
KK-spectrum, one is tasked to reconstruct the underline spatial luminosity of an object under observation. In
6+
this specific case, we will be using an example from `Corona et al., 2019, Enhancing joint reconstruction and
7+
segmentation with non-convex Bregman iteration`.
8+
9+
We first consider the imaging problem defined by the following cost functuon
10+
11+
.. math::
12+
\argmin_\mathbf{x} \|\mathbf{y}-\mathbf{Ax}\|_2^2 + \alpha TV(\mathbf{x})
13+
14+
where the operator :math:`\mathbf{A}` performs a 2D-Fourier transform followed by sampling of the KK plane, :math:`\mathbf{x}`
15+
is the object of interest and :math:`\mathbf{y}` the set of available Fourier coefficients.
16+
17+
Once the model is reconstructed, we solve a second inverse problem with the aim of segmenting the retrieved object into
18+
:math:`N` classes of different luminosity.
19+
"""
20+
import numpy as np
21+
import matplotlib.pyplot as plt
22+
import pylops
23+
from scipy.io import loadmat
24+
25+
import pyproximal
26+
27+
plt.close('all')
28+
np.random.seed(10)
29+
30+
###############################################################################
31+
# Let's start by loading the data and the sampling mask
32+
mat = loadmat('../testdata/brainphantom.mat')
33+
mat1 = loadmat('../testdata/spiralsampling.mat')
34+
gt = mat['gt']
35+
seggt = mat['gt_seg']
36+
sampling = mat1['samp']
37+
sampling1 = np.fft.ifftshift(sampling)
38+
39+
fig, axs = plt.subplots(1, 3, figsize=(15, 6))
40+
axs[0].imshow(gt, cmap='gray')
41+
axs[0].axis('tight')
42+
axs[0].set_title("Object")
43+
axs[1].imshow(seggt, cmap='Accent')
44+
axs[1].axis('tight')
45+
axs[1].set_title("Segmentation")
46+
axs[2].imshow(sampling, cmap='gray')
47+
axs[2].axis('tight')
48+
axs[2].set_title("Sampling mask")
49+
plt.tight_layout()
50+
51+
###############################################################################
52+
# We can now create the MRI operator
53+
Fop = pylops.signalprocessing.FFT2D(dims=gt.shape)
54+
Rop = pylops.Restriction(gt.size, np.where(sampling1.ravel() == 1)[0],
55+
dtype=np.complex128)
56+
Dop = Rop * Fop
57+
58+
# KK spectrum
59+
GT = Fop * gt.ravel()
60+
GT = GT.reshape(gt.shape)
61+
62+
# Data (Masked KK spectrum)
63+
d = Dop * gt.ravel()
64+
65+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
66+
axs[0].imshow(np.fft.fftshift(np.abs(GT)), vmin=0, vmax=1, cmap='gray')
67+
axs[0].axis('tight')
68+
axs[0].set_title("Spectrum")
69+
axs[1].plot(np.fft.fftshift(np.abs(d)), 'k', lw=2)
70+
axs[1].axis('tight')
71+
axs[1].set_title("Masked Spectrum")
72+
plt.tight_layout()
73+
74+
###############################################################################
75+
# Let's try now to reconstruct the object from its measurement. The simplest
76+
# approach entails simply filling the missing values in the KK spectrum with
77+
# zeros and applying inverse FFT.
78+
79+
GTzero = sampling1 * GT
80+
gtzero = (Fop.H * GTzero).real
81+
82+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
83+
axs[0].imshow(gt, cmap='gray')
84+
axs[0].axis('tight')
85+
axs[0].set_title("True Object")
86+
axs[1].imshow(gtzero, cmap='gray')
87+
axs[1].axis('tight')
88+
axs[1].set_title("Zero-filling Object")
89+
plt.tight_layout()
90+
91+
###############################################################################
92+
# We can now do better if we introduce some prior information in the form of
93+
# TV on the solution
94+
95+
with pylops.disabled_ndarray_multiplication():
96+
sigma = 0.04
97+
l1 = pyproximal.proximal.L21(ndim=2)
98+
l2 = pyproximal.proximal.L2(Op=Dop, b=d.ravel(), niter=50, warm=True)
99+
Gop = sigma * pylops.Gradient(dims=gt.shape, edge=True, kind='forward', dtype=np.complex)
100+
101+
L = sigma ** 2 * 8
102+
tau = .99 / np.sqrt(L)
103+
mu = .99 / np.sqrt(L)
104+
105+
gtpd = pyproximal.optimization.primaldual.PrimalDual(l2, l1, Gop, x0=np.zeros(gt.size),
106+
tau=tau, mu=mu, theta=1.,
107+
niter=100, show=True)
108+
gtpd = np.real(gtpd.reshape(gt.shape))
109+
110+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
111+
axs[0].imshow(gt, cmap='gray')
112+
axs[0].axis('tight')
113+
axs[0].set_title("True Object")
114+
axs[1].imshow(gtpd, cmap='gray')
115+
axs[1].axis('tight')
116+
axs[1].set_title("TV-reg Object")
117+
plt.tight_layout()
118+
119+
###############################################################################
120+
# Finally we segment our reconstructed model into 4 classes.
121+
122+
cl = np.array([0.01, 0.43, 0.65, 0.8])
123+
ncl = len(cl)
124+
segpd_prob, segpd = \
125+
pyproximal.optimization.segmentation.Segment(gtpd, cl, 1., 0.001,
126+
niter=10, show=True,
127+
kwargs_simplex=dict(engine='numba'))
128+
129+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
130+
axs[0].imshow(seggt, cmap='Accent')
131+
axs[0].axis('tight')
132+
axs[0].set_title("True Classes")
133+
axs[1].imshow(segpd, cmap='Accent')
134+
axs[1].axis('tight')
135+
axs[1].set_title("Estimated Classes")
136+
plt.tight_layout()
137+
138+
fig, axs = plt.subplots(1, 4, figsize=(15, 6))
139+
for i, ax in enumerate(axs):
140+
ax.imshow(segpd_prob[:, i].reshape(gt.shape), cmap='Reds')
141+
axs[i].axis('tight')
142+
axs[i].set_title(f"Class {i}")
143+
plt.tight_layout()

0 commit comments

Comments
 (0)