Skip to content

Commit b1983a1

Browse files
Rebuilt plot_sinogram_profiles.py for 4D tof data (#1370)
* Rebuilt plot_sinogram_profiles for 4D tof data * Add test_generate_1d_from_4d to pytest * added release notes --------- Co-authored-by: Kris Thielemans <KrisThielemans@users.noreply.github.com>
1 parent ac0d0b3 commit b1983a1

File tree

6 files changed

+330
-74
lines changed

6 files changed

+330
-74
lines changed

.appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ build_script:
3636
- echo Using Miniconda %MINICONDA%
3737
- "set PATH=%MINICONDA%;%MINICONDA%\\Scripts;%MINICONDA%\\Library\\bin;%PATH%"
3838
# install parallelproj and Python stuff
39-
- conda install -c conda-forge -yq libparallelproj swig numpy pytest
39+
- conda install -c conda-forge -yq libparallelproj swig numpy pytest matplotlib
4040
- CALL conda.bat activate base
4141
- python --version
4242
- mkdir build

.github/workflows/build-test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,10 @@ jobs:
235235
# brew install openblas
236236
# export OPENBLAS=$(brew --prefix openblas)
237237
#python -m pip install --no-cache-dir --no-binary numpy numpy # avoid the cached .whl!
238-
python -m pip install numpy pytest
238+
python -m pip install numpy pytest matplotlib
239239
;;
240240
(*)
241-
python -m pip install numpy pytest
241+
python -m pip install numpy pytest matplotlib
242242
;;
243243
esac
244244

documentation/release_6.3.htm

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ <h4>C++ tests</h4>
125125

126126
<h4>recon_test_pack</h4>
127127

128+
<h3>Changes to examples</h3>
129+
<uk>
130+
<li>
131+
Python example <code>plot_sinogram_profiles.py</code> has been renamed to <code>plot_projdata_profiles.py</code>
132+
and generalised to work with TOF dimensions etc. A small <code>pytest</code> has been added as well.
133+
<a href=https://github.com/UCL/STIR/pull/1370>PR #1370</a>
134+
</li>
135+
</ul>
136+
128137
</body>
129138

130139
</html>
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Demo to plot the profile of projection data using STIR
2+
# To run in "normal" Python, you would type the following in the command line
3+
# execfile('plot_projdata_profiles.py')
4+
# In ipython, you can use
5+
# %run plot_projdata_profiles.py
6+
# Or of course
7+
# import plot_projdata_profiles
8+
9+
# Copyright 2021 University College London
10+
# Copyright 2024 Prescient Imaging
11+
12+
# Authors: Robert Twyman
13+
14+
# This file is part of STIR.
15+
# SPDX-License-Identifier: Apache-2.0
16+
# See STIR/LICENSE.txt for details
17+
18+
from __future__ import annotations # for supporting newer typing info in old Python versions (from 3.7)
19+
20+
import argparse
21+
import sys
22+
23+
import matplotlib.pyplot as plt
24+
import numpy as np
25+
import stir
26+
import stirextra
27+
28+
PROJDATA_DIM_MAP = {
29+
0: "TOF",
30+
1: "Axial Segment",
31+
2: "View",
32+
3: "Tangential"
33+
}
34+
35+
36+
def get_projdata_from_file_as_numpy(filename: str) -> np.ndarray | None:
37+
"""
38+
Load a Projdata file and convert it to a NumPy array.
39+
Args:
40+
filename: The filename of the Projdata file to load.
41+
Returns:
42+
result: The NumPy array.
43+
"""
44+
try:
45+
projdata: stir.ProjData = stir.ProjData.read_from_file(filename)
46+
except Exception as e:
47+
print(f"Error reading file {filename}: {e}")
48+
return None
49+
50+
try:
51+
return stirextra.to_numpy(projdata)
52+
except Exception as e:
53+
print(f"Error converting to numpy: {e}")
54+
return None
55+
56+
57+
def get_projection_data_as_array(f: str | stir.ProjData) -> np.ndarray | None:
58+
"""
59+
Get the projection data from a file or object.
60+
Args:
61+
f: The file name or object to get the projection data from.
62+
Returns:
63+
result: The projection data as a NumPy array.
64+
"""
65+
# Get the input data from a file or object
66+
if isinstance(f, str):
67+
print(f"Handling:\n\t{f}")
68+
return get_projdata_from_file_as_numpy(f)
69+
70+
elif isinstance(f, stir.ProjData):
71+
try:
72+
return stirextra.to_numpy(f)
73+
except AttributeError as e:
74+
print(f"AttributeError converting to projdata to numpy.\nError message{e}")
75+
return None
76+
77+
else:
78+
print(f"Unknown type for {f=}")
79+
return None
80+
81+
82+
def compress_and_extract_1d_from_nd_array(data: np.ndarray,
83+
display_axis: int,
84+
axes_indices: list[int | None] | None = None
85+
) -> np.ndarray:
86+
"""
87+
Generate a 1D array from an n-dimensional NumPy array based on specified parameters.
88+
The display is the axis to be extracted.
89+
The axes_indices is a list of indices to extract from each dimension.
90+
If the index is None, the entire dimension is summed.
91+
If the index is not None, the data is taken from that index.
92+
Args:
93+
data: The n-dimensional NumPy array.
94+
display_axis: The index of the dimension to be treated as the horizontal component.
95+
axes_indices: A list of indices to extract from each dimension.
96+
If None, all indices, except the display axis, are summed.
97+
Returns:
98+
result: The 1D NumPy array.
99+
Exceptions:
100+
ValueError: If the data is not at least 2D.
101+
ValueError: If the number of axes indices does not match the number of dimensions.
102+
ValueError: If the indices are out of bounds.
103+
"""
104+
if data.ndim < 2:
105+
raise ValueError(f"Data must have at least 2 dimensions, not {data.ndim}D")
106+
107+
if axes_indices is None:
108+
axes_indices = [None] * data.ndim
109+
if not len(axes_indices) == data.ndim:
110+
raise ValueError(
111+
f"Number of axes indices ({len(axes_indices)}) must match the number of dimensions ({data.ndim})")
112+
113+
working_axis = 0
114+
# Check if indices are within valid range for all dimensions
115+
for data_axis, index in enumerate(axes_indices):
116+
if index is not None and not np.all(np.logical_and(index >= 0, index < data.shape[data_axis])):
117+
raise ValueError(f"Indices for axis {data_axis} are out of bounds. {index=}, {data.shape[data_axis]=}")
118+
119+
for data_axis in range(data.ndim):
120+
if display_axis == data_axis:
121+
working_axis += 1
122+
elif axes_indices[data_axis] is None:
123+
data = np.sum(data, axis=working_axis)
124+
else:
125+
data = np.take(data, axes_indices[data_axis], axis=working_axis)
126+
return data
127+
128+
129+
def plot_projdata_profiles(projection_data_list: list[stir.ProjData] | list[str],
130+
display_axis: int = 3,
131+
data_indices: list[int | None] | None = None,
132+
) -> None:
133+
"""
134+
Plots the profiles of the projection data.
135+
Compress (via sum) and extract a 1D array from a 4D array of projection data for each element of the list.
136+
Args:
137+
projection_data_list: list of projection data file names or stir.ProjData objects to load and plot.
138+
display_axis: The horizontal component of the projection data to plot.
139+
data_indices: The indices to extract from the projection data (None indices are summed).
140+
Returns:
141+
None
142+
"""
143+
144+
plt.figure()
145+
ax = plt.subplot(111)
146+
147+
for f in projection_data_list:
148+
if isinstance(f, str):
149+
label = f
150+
else:
151+
label = ""
152+
153+
projdata_npy = get_projection_data_as_array(f)
154+
if projdata_npy is None:
155+
continue
156+
157+
# Generate the 1D array
158+
try:
159+
plot_data = compress_and_extract_1d_from_nd_array(projdata_npy, display_axis, data_indices)
160+
except ValueError as e:
161+
print(f"Error generating 1D array object.\nError message: {e}")
162+
continue
163+
164+
plt.plot(plot_data, label=label)
165+
166+
if len(plt.gca().get_lines()) == 0:
167+
print("Something went wrong! No data to plot.")
168+
return
169+
170+
# Identify sum and extraction axes
171+
sum_axis = [i for i, x in enumerate(data_indices) if x is None and i != display_axis]
172+
index_axis = [i for i, x in enumerate(data_indices) if x is not None and i != display_axis]
173+
174+
# Extract labels and values for sum and extraction axes
175+
sum_axis_labels = [PROJDATA_DIM_MAP[i] for i in sum_axis]
176+
extraction_axis_labels = [PROJDATA_DIM_MAP[i] for i in index_axis]
177+
index_values = [data_indices[i] for i in index_axis]
178+
179+
# Plot title
180+
plt.title(f"Summing {sum_axis_labels} axis and extracting {extraction_axis_labels} with values {index_values}")
181+
plt.xlabel(f"{PROJDATA_DIM_MAP[display_axis]}")
182+
ax.legend()
183+
plt.show()
184+
185+
186+
if __name__ == '__main__':
187+
parser = argparse.ArgumentParser(sys.argv[0])
188+
parser.description = ("This script loads, sums axis' and plots profiles over input projection data files."
189+
"The default is to sum over all components, except the display axis."
190+
"The indices used are array based, not STIR offset based.")
191+
parser.add_argument('filenames',
192+
nargs='*',
193+
help='Projection data file names to show, can handle multiple.')
194+
parser.add_argument('--display_axis',
195+
dest="display_axis",
196+
type=int,
197+
default=3,
198+
help='The horizontal component of the projection data to plot.'
199+
'The default is -1 indicating a sum over all components. '
200+
'0: TOF, 1: axial (and segment), 2: view, 3: tangential.')
201+
parser.add_argument('--tof',
202+
dest="tof",
203+
type=int,
204+
default=None,
205+
help='The TOF value of the projection data to plot.'
206+
'The default is to sum over all TOF values.')
207+
parser.add_argument('--axial_segment',
208+
dest="axial_segment",
209+
type=int,
210+
default=None,
211+
help='The axial segment number of the projection data to plot.'
212+
'The default is to sum over all axial segments.')
213+
parser.add_argument('--view',
214+
dest="view",
215+
type=int,
216+
default=None,
217+
help='The view of the projection data to plot.'
218+
'The default is to sum over all views.')
219+
parser.add_argument('--tangential_pos',
220+
dest="tangential",
221+
type=int,
222+
default=None,
223+
help='The tangential position of the projection data to plot.'
224+
'The default is to sum over all tangential positions.')
225+
226+
args = parser.parse_args()
227+
228+
if len(args.filenames) < 1:
229+
parser.print_help()
230+
exit(0)
231+
232+
plot_projdata_profiles(projection_data_list=args.filenames,
233+
display_axis=args.display_axis,
234+
data_indices=[args.tof, args.axial_segment, args.view, args.tangential]
235+
)

examples/python/plot_sinogram_profiles.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

0 commit comments

Comments
 (0)