Skip to content

Commit ec05d16

Browse files
committed
ENH: Add gaussian process DWI signal representation notebooks
Add gaussian process DWI signal representation notebooks: - One of the notebooks uses a simulated DWI signal. - The second notebook uses a real DWI signal.
1 parent 6b6ba70 commit ec05d16

File tree

2 files changed

+431
-0
lines changed

2 files changed

+431
-0
lines changed

docs/notebooks/dwi_gp.ipynb

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
{
2+
"cells": [
3+
{
4+
"metadata": {},
5+
"cell_type": "markdown",
6+
"source": "Gaussian process notebook",
7+
"id": "486923b289155658"
8+
},
9+
{
10+
"metadata": {},
11+
"cell_type": "code",
12+
"source": [
13+
"import tempfile\n",
14+
"from pathlib import Path\n",
15+
"\n",
16+
"import numpy as np\n",
17+
"from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n",
18+
"\n",
19+
"from nifreeze.model.dmri import GPModel\n",
20+
"from nifreeze.data.dmri import DWI\n",
21+
"from nifreeze.data.splitting import lovo_split\n",
22+
"\n",
23+
"datadir = Path(\"../../test\") # Adapt to your local path or download to a temp location using wget\n",
24+
"\n",
25+
"kernel = DotProduct() + WhiteKernel()\n",
26+
"\n",
27+
"dwi = DWI.from_filename(datadir / \"dwi.h5\")\n",
28+
"\n",
29+
"_dwi_data = dwi.dataobj\n",
30+
"# Use a subset of the data for now to see that something is written to the\n",
31+
"# output\n",
32+
"# bvecs = dwi.gradients[:3, :].T\n",
33+
"bvecs = dwi.gradients[:3, 10:13].T # b0 values have already been masked\n",
34+
"# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\n",
35+
"dwi_data = _dwi_data[60:63, 60:64, 40:45, 10:13]\n",
36+
"\n",
37+
"# ToDo\n",
38+
"# Provide proper values/estimates for these\n",
39+
"a = 1\n",
40+
"h = 1 # should be a NIfTI image\n",
41+
"\n",
42+
"# ToDo\n",
43+
"# Check if this should be nifreeze.model.gpr.EddyMotionGPR\n",
44+
"# Check the overlap/rework to distinguish from https://github.com/nipreps/eddymotion/blob/main/docs/notebooks/dwi_gp_estimation.ipynb\n",
45+
"num_iterations = 5\n",
46+
"gp = GPModel(\n",
47+
" dwi=dwi, a=a, h=h, kernel=kernel, num_iterations=num_iterations\n",
48+
")\n",
49+
"indices = list(range(bvecs.shape[0]))\n",
50+
"# ToDo\n",
51+
"# This should be done within the GP model class\n",
52+
"# Apply lovo strategy properly\n",
53+
"# Vectorize and parallelize\n",
54+
"result_mean = np.zeros_like(dwi_data)\n",
55+
"result_stddev = np.zeros_like(dwi_data)\n",
56+
"for idx in indices:\n",
57+
" lovo_idx = np.ones(len(indices), dtype=bool)\n",
58+
" lovo_idx[idx] = False\n",
59+
" X = bvecs[lovo_idx]\n",
60+
" for i in range(dwi_data.shape[0]):\n",
61+
" for j in range(dwi_data.shape[1]):\n",
62+
" for k in range(dwi_data.shape[2]):\n",
63+
" # ToDo\n",
64+
" # Use a mask to avoid traversing background data\n",
65+
" y = dwi_data[i, j, k, lovo_idx]\n",
66+
" gp.fit(X, y)\n",
67+
" pred_mean, pred_stddev = gp.predict(\n",
68+
" bvecs[idx, :][np.newaxis]\n",
69+
" ) # Can take multiple values X[:2, :]\n",
70+
" result_mean[i, j, k, idx] = pred_mean.item()\n",
71+
" result_stddev[i, j, k, idx] = pred_stddev.item()"
72+
],
73+
"id": "da2274009534db61",
74+
"outputs": [],
75+
"execution_count": null
76+
},
77+
{
78+
"metadata": {},
79+
"cell_type": "markdown",
80+
"source": "Plot the data",
81+
"id": "77e77cd4c73409d3"
82+
},
83+
{
84+
"metadata": {},
85+
"cell_type": "code",
86+
"source": [
87+
"from matplotlib import pyplot as plt \n",
88+
"%matplotlib inline\n",
89+
"\n",
90+
"s = dwi_data[1, 1, 2, :]\n",
91+
"s_hat_mean = result_mean[1, 1, 2, :]\n",
92+
"s_hat_stddev = result_stddev[1, 1, 2, :]\n",
93+
"x = np.asarray(indices)\n",
94+
"\n",
95+
"fig, ax = plt.subplots()\n",
96+
"ax.plot(x, s_hat_mean, c=\"orange\", label=\"predicted\")\n",
97+
"plt.fill_between(\n",
98+
" x.ravel(),\n",
99+
" s_hat_mean - 1.96 * s_hat_stddev,\n",
100+
" s_hat_mean + 1.96 * s_hat_stddev,\n",
101+
" alpha=0.5,\n",
102+
" color=\"orange\",\n",
103+
" label=r\"95% confidence interval\",\n",
104+
")\n",
105+
"plt.scatter(x, s, c=\"b\", label=\"ground truth\")\n",
106+
"ax.set_xlabel(\"bvec indices\")\n",
107+
"ax.set_ylabel(\"signal\")\n",
108+
"ax.legend()\n",
109+
"plt.title(\"Gaussian process regression on dataset\")\n",
110+
"\n",
111+
"plt.show()"
112+
],
113+
"id": "4e51f22890fb045a",
114+
"outputs": [],
115+
"execution_count": null
116+
},
117+
{
118+
"metadata": {},
119+
"cell_type": "markdown",
120+
"source": [
121+
"Plot the DWI signal for a given voxel\n",
122+
"Compute the DWI signal value wrt the b0 (how much larger/smaller is and add that delta to the unit sphere?) for each bvec direction and plot that?"
123+
],
124+
"id": "694a4c075457425d"
125+
},
126+
{
127+
"metadata": {},
128+
"cell_type": "code",
129+
"source": [
130+
"# from mpl_toolkits.mplot3d import Axes3D\n",
131+
"# fig, ax = plt.subplots()\n",
132+
"# ax = fig.add_subplot(111, projection='3d')\n",
133+
"# plt.scatter(xx, yy, zz)"
134+
],
135+
"id": "bb7d2aef53ac99f0",
136+
"outputs": [],
137+
"execution_count": null
138+
},
139+
{
140+
"metadata": {},
141+
"cell_type": "markdown",
142+
"source": "Plot the DWI signal brain data\n",
143+
"id": "62d7bc609b65c7cf"
144+
},
145+
{
146+
"metadata": {},
147+
"cell_type": "code",
148+
"source": "# plot_dwi(dmri_dataset.dataobj, dmri_dataset.affine, gradient=data_test[1], black_bg=True)",
149+
"id": "edb0e9d255516e38",
150+
"outputs": [],
151+
"execution_count": null
152+
},
153+
{
154+
"metadata": {},
155+
"cell_type": "markdown",
156+
"source": "Plot the predicted DWI signal",
157+
"id": "1a52e2450fc61dc6"
158+
},
159+
{
160+
"metadata": {},
161+
"cell_type": "code",
162+
"source": "# plot_dwi(predicted, dmri_dataset.affine, gradient=data_test[1], black_bg=True);",
163+
"id": "66150cf337b395e0",
164+
"outputs": [],
165+
"execution_count": null
166+
}
167+
],
168+
"metadata": {
169+
"kernelspec": {
170+
"display_name": "Python 3",
171+
"language": "python",
172+
"name": "python3"
173+
},
174+
"language_info": {
175+
"codemirror_mode": {
176+
"name": "ipython",
177+
"version": 2
178+
},
179+
"file_extension": ".py",
180+
"mimetype": "text/x-python",
181+
"name": "python",
182+
"nbconvert_exporter": "python",
183+
"pygments_lexer": "ipython2",
184+
"version": "2.7.6"
185+
}
186+
},
187+
"nbformat": 4,
188+
"nbformat_minor": 5
189+
}

0 commit comments

Comments
 (0)