1+ r"""
2+ MRI Imaging and Segmentation of Brain
3+ =====================================
4+ This tutorial considers the well-known problem of MRI imaging, where given the availability of a sparsely sampled
5+ KK-spectrum, one is tasked to reconstruct the underline spatial luminosity of an object under observation. In
6+ this specific case, we will be using an example from `Corona et al., 2019, Enhancing joint reconstruction and
7+ segmentation with non-convex Bregman iteration`.
8+
9+ We first consider the imaging problem defined by the following cost functuon
10+
11+ .. math::
12+ \argmin_\mathbf{x} \|\mathbf{y}-\mathbf{Ax}\|_2^2 + \alpha TV(\mathbf{x})
13+
14+ where the operator :math:`\mathbf{A}` performs a 2D-Fourier transform followed by sampling of the KK plane, :math:`\mathbf{x}`
15+ is the object of interest and :math:`\mathbf{y}` the set of available Fourier coefficients.
16+
17+ Once the model is reconstructed, we solve a second inverse problem with the aim of segmenting the retrieved object into
18+ :math:`N` classes of different luminosity.
19+ """
20+ import numpy as np
21+ import matplotlib .pyplot as plt
22+ import pylops
23+ from scipy .io import loadmat
24+
25+ import pyproximal
26+
27+ plt .close ('all' )
28+ np .random .seed (10 )
29+
30+ ###############################################################################
31+ # Let's start by loading the data and the sampling mask
32+ mat = loadmat ('../testdata/brainphantom.mat' )
33+ mat1 = loadmat ('../testdata/spiralsampling.mat' )
34+ gt = mat ['gt' ]
35+ seggt = mat ['gt_seg' ]
36+ sampling = mat1 ['samp' ]
37+ sampling1 = np .fft .ifftshift (sampling )
38+
39+ fig , axs = plt .subplots (1 , 3 , figsize = (15 , 6 ))
40+ axs [0 ].imshow (gt , cmap = 'gray' )
41+ axs [0 ].axis ('tight' )
42+ axs [0 ].set_title ("Object" )
43+ axs [1 ].imshow (seggt , cmap = 'Accent' )
44+ axs [1 ].axis ('tight' )
45+ axs [1 ].set_title ("Segmentation" )
46+ axs [2 ].imshow (sampling , cmap = 'gray' )
47+ axs [2 ].axis ('tight' )
48+ axs [2 ].set_title ("Sampling mask" )
49+ plt .tight_layout ()
50+
51+ ###############################################################################
52+ # We can now create the MRI operator
53+ Fop = pylops .signalprocessing .FFT2D (dims = gt .shape )
54+ Rop = pylops .Restriction (gt .size , np .where (sampling1 .ravel () == 1 )[0 ],
55+ dtype = np .complex128 )
56+ Dop = Rop * Fop
57+
58+ # KK spectrum
59+ GT = Fop * gt .ravel ()
60+ GT = GT .reshape (gt .shape )
61+
62+ # Data (Masked KK spectrum)
63+ d = Dop * gt .ravel ()
64+
65+ fig , axs = plt .subplots (1 , 2 , figsize = (12 , 6 ))
66+ axs [0 ].imshow (np .fft .fftshift (np .abs (GT )), vmin = 0 , vmax = 1 , cmap = 'gray' )
67+ axs [0 ].axis ('tight' )
68+ axs [0 ].set_title ("Spectrum" )
69+ axs [1 ].plot (np .fft .fftshift (np .abs (d )), 'k' , lw = 2 )
70+ axs [1 ].axis ('tight' )
71+ axs [1 ].set_title ("Masked Spectrum" )
72+ plt .tight_layout ()
73+
74+ ###############################################################################
75+ # Let's try now to reconstruct the object from its measurement. The simplest
76+ # approach entails simply filling the missing values in the KK spectrum with
77+ # zeros and applying inverse FFT.
78+
79+ GTzero = sampling1 * GT
80+ gtzero = (Fop .H * GTzero ).real
81+
82+ fig , axs = plt .subplots (1 , 2 , figsize = (12 , 6 ))
83+ axs [0 ].imshow (gt , cmap = 'gray' )
84+ axs [0 ].axis ('tight' )
85+ axs [0 ].set_title ("True Object" )
86+ axs [1 ].imshow (gtzero , cmap = 'gray' )
87+ axs [1 ].axis ('tight' )
88+ axs [1 ].set_title ("Zero-filling Object" )
89+ plt .tight_layout ()
90+
91+ ###############################################################################
92+ # We can now do better if we introduce some prior information in the form of
93+ # TV on the solution
94+
95+ with pylops .disabled_ndarray_multiplication ():
96+ sigma = 0.04
97+ l1 = pyproximal .proximal .L21 (ndim = 2 )
98+ l2 = pyproximal .proximal .L2 (Op = Dop , b = d .ravel (), niter = 50 , warm = True )
99+ Gop = sigma * pylops .Gradient (dims = gt .shape , edge = True , kind = 'forward' , dtype = np .complex )
100+
101+ L = sigma ** 2 * 8
102+ tau = .99 / np .sqrt (L )
103+ mu = .99 / np .sqrt (L )
104+
105+ gtpd = pyproximal .optimization .primaldual .PrimalDual (l2 , l1 , Gop , x0 = np .zeros (gt .size ),
106+ tau = tau , mu = mu , theta = 1. ,
107+ niter = 100 , show = True )
108+ gtpd = np .real (gtpd .reshape (gt .shape ))
109+
110+ fig , axs = plt .subplots (1 , 2 , figsize = (12 , 6 ))
111+ axs [0 ].imshow (gt , cmap = 'gray' )
112+ axs [0 ].axis ('tight' )
113+ axs [0 ].set_title ("True Object" )
114+ axs [1 ].imshow (gtpd , cmap = 'gray' )
115+ axs [1 ].axis ('tight' )
116+ axs [1 ].set_title ("TV-reg Object" )
117+ plt .tight_layout ()
118+
119+ ###############################################################################
120+ # Finally we segment our reconstructed model into 4 classes.
121+
122+ cl = np .array ([0.01 , 0.43 , 0.65 , 0.8 ])
123+ ncl = len (cl )
124+ segpd_prob , segpd = \
125+ pyproximal .optimization .segmentation .Segment (gtpd , cl , 1. , 0.001 ,
126+ niter = 10 , show = True ,
127+ kwargs_simplex = dict (engine = 'numba' ))
128+
129+ fig , axs = plt .subplots (1 , 2 , figsize = (12 , 6 ))
130+ axs [0 ].imshow (seggt , cmap = 'Accent' )
131+ axs [0 ].axis ('tight' )
132+ axs [0 ].set_title ("True Classes" )
133+ axs [1 ].imshow (segpd , cmap = 'Accent' )
134+ axs [1 ].axis ('tight' )
135+ axs [1 ].set_title ("Estimated Classes" )
136+ plt .tight_layout ()
137+
138+ fig , axs = plt .subplots (1 , 4 , figsize = (15 , 6 ))
139+ for i , ax in enumerate (axs ):
140+ ax .imshow (segpd_prob [:, i ].reshape (gt .shape ), cmap = 'Reds' )
141+ axs [i ].axis ('tight' )
142+ axs [i ].set_title (f"Class { i } " )
143+ plt .tight_layout ()
0 commit comments