Skip to content

Commit 56c92a0

Browse files
committed
REF: Generalize bvals and uptake value iterators into a single iterator
Generalize bvals and uptake value iterators into a single iterator that is able to traverse the values in ascending or descending order depending on whether it is provided a list of b-values (DWI) or uptake values (PET). Follow-up to commit a1310e6.
1 parent 09df9e8 commit 56c92a0

File tree

3 files changed

+85
-93
lines changed

3 files changed

+85
-93
lines changed

src/nifreeze/utils/iterators.py

Lines changed: 44 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import random
2626
from itertools import chain, zip_longest
27-
from typing import Iterator
27+
from typing import Iterator, Sequence
2828

2929
DEFAULT_ROUND_DECIMALS = 2
3030
"""Round decimals to use when comparing values to be sorted for iteration purposes."""
@@ -172,14 +172,14 @@ def random_iterator(**kwargs) -> Iterator[int]:
172172

173173

174174
def _value_iterator(
175-
values: list, ascending: bool, round_decimals: int = DEFAULT_ROUND_DECIMALS
175+
values: Sequence[float], ascending: bool, round_decimals: int = DEFAULT_ROUND_DECIMALS
176176
) -> Iterator[int]:
177177
"""
178178
Traverse the given values in ascending or descenting order.
179179
180180
Parameters
181181
----------
182-
values : :obj:`list`
182+
values : :obj:`Sequence`
183183
List of values to traverse.
184184
ascending : :obj:`bool`
185185
If ``True``, traverse in ascending order; traverse in descending order
@@ -211,70 +211,57 @@ def _value_iterator(
211211
return (index[1] for index in indexed_vals)
212212

213213

214-
def bvalue_iterator(*_, **kwargs) -> Iterator[int]:
215-
"""
216-
Traverse the volumes in a DWI dataset by increasing b-value.
217-
218-
Parameters
219-
----------
220-
bvals : :obj:`list`
221-
List of b-values corresponding to all orientations of the dataset.
222-
Please note that ``bvals`` is a keyword argument and MUST be provided
223-
to generate the volume sequence.
224-
225-
Yields
226-
------
227-
:obj:`int`
228-
The next index.
214+
def monotonic_value_iterator(*_, **kwargs) -> Iterator[int]:
215+
try:
216+
feature = next(k for k in (BVALS_KWARG, UPTAKE_KWARG) if kwargs.get(k) is not None)
217+
except StopIteration:
218+
raise TypeError(KWARG_ERROR_MSG.format(kwarg=f"{BVALS_KWARG} or {UPTAKE_KWARG}"))
229219

230-
Examples
231-
--------
232-
>>> list(bvalue_iterator(bvals=[0.0, 0.0, 1000.0, 1000.0, 700.0, 700.0, 2000.0, 2000.0, 0.0]))
233-
[0, 1, 8, 4, 5, 2, 3, 6, 7]
234-
235-
"""
236-
bvals = kwargs.pop(BVALS_KWARG, None)
237-
if bvals is None:
238-
raise TypeError(KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG))
220+
ascending = feature == BVALS_KWARG
221+
values = kwargs[feature]
239222
return _value_iterator(
240-
bvals, ascending=True, round_decimals=kwargs.pop("round_decimals", DEFAULT_ROUND_DECIMALS)
223+
values,
224+
ascending=ascending,
225+
round_decimals=kwargs.get("round_decimals", DEFAULT_ROUND_DECIMALS),
241226
)
242227

243228

244-
def uptake_iterator(*_, **kwargs) -> Iterator[int]:
245-
"""
246-
Traverse the volumes in a PET dataset by decreasing uptake value.
229+
monotonic_value_iterator.__doc__ = f"""
230+
Traverse the volumes by increasing b-value in a DWI dataset or by decreasing
231+
uptake value in a PET dataset.
247232
248-
This function assumes that each uptake value corresponds to a single volume,
249-
and that this value summarizes the uptake of the volume in a meaningful way,
250-
e.g. a mean value across the entire volume.
233+
This function requires ``bvals`` or ``uptake`` be a keyword argument to generate
234+
the volume sequence. The b-values are assumed to all orientations in a DWI
235+
dataset, and uptake uptake values correspond to all volumes in a PET dataset.
251236
252-
Parameters
253-
----------
254-
uptake : :obj:`list`
255-
List of uptake values corresponding to all volumes of the dataset.
256-
Please note that ``uptake`` is a keyword argument and MUST be provided
257-
to generate the volume sequence.
237+
It is assumed that each uptake value corresponds to a single volume, and that
238+
this value summarizes the uptake of the volume in a meaningful way, e.g. a mean
239+
value across the entire volume.
258240
259-
Yields
260-
------
261-
:obj:`int`
262-
The next index.
241+
Other Parameters
242+
----------------
243+
{SIZE_KEYS_DOC}
263244
264-
Examples
265-
--------
266-
>>> list(uptake_iterator(uptake=[-1.23, 1.06, 1.02, 1.38, -1.46, -1.12, -1.19, 1.24, 1.05]))
267-
[3, 7, 1, 8, 2, 5, 6, 0, 4]
245+
Notes
246+
-----
247+
Only one of the above keyword arguments may be provided at a time. If ``size``
248+
is given, all other size-related keyword arguments will be ignored. If ``size``
249+
is not provided, the function will attempt to infer the number of volumes from
250+
the length or value of the provided keyword argument. If more than one such
251+
keyword is provided, a :exc:`ValueError` will be raised.
268252
269-
"""
270-
uptake = kwargs.pop(UPTAKE_KWARG, None)
271-
if uptake is None:
272-
raise TypeError(KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG))
273-
return _value_iterator(
274-
uptake,
275-
ascending=False,
276-
round_decimals=kwargs.pop("round_decimals", DEFAULT_ROUND_DECIMALS),
277-
)
253+
Yields
254+
------
255+
:obj:`int`
256+
The next index.
257+
258+
Examples
259+
--------
260+
>>> list(monotonic_value_iterator(bvals=[0.0, 0.0, 1000.0, 1000.0, 700.0, 700.0, 2000.0, 2000.0, 0.0]))
261+
[0, 1, 8, 4, 5, 2, 3, 6, 7]
262+
>>> list(monotonic_value_iterator(uptake=[-1.23, 1.06, 1.02, 1.38, -1.46, -1.12, -1.19, 1.24, 1.05]))
263+
[3, 7, 1, 8, 2, 5, 6, 0, 4]
264+
"""
278265

279266

280267
def centralsym_iterator(**kwargs) -> Iterator[int]:

test/test_estimator.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def mock_iterator(*_, **kwargs):
133133
("random", iterators.random_iterator, "pet"),
134134
("centralsym", iterators.centralsym_iterator, "dwi"),
135135
("centralsym", iterators.centralsym_iterator, "pet"),
136-
("bvalue", iterators.bvalue_iterator, "dwi"),
137-
("uptake", iterators.uptake_iterator, "pet"),
136+
("monotonic_value", iterators.monotonic_value_iterator, "dwi"),
137+
("monotonic_value", iterators.monotonic_value_iterator, "pet"),
138138
],
139139
)
140140
def test_estimator_iterator_index_match(
@@ -206,10 +206,13 @@ class DummyXForm:
206206
return
207207
elif strategy == "centralsym":
208208
expected_indices = list(iterator_func(size=n_vols))
209-
elif strategy == "bvalue":
210-
expected_indices = list(iterator_func(bvals=bvals))
211-
elif strategy == "uptake":
212-
expected_indices = list(iterator_func(uptake=uptake))
209+
elif strategy == "monotonic_value":
210+
if modality == "dwi":
211+
expected_indices = list(iterator_func(bvals=bvals, ascending=True))
212+
elif modality == "pet":
213+
expected_indices = list(iterator_func(uptake=uptake, ascending=False))
214+
else:
215+
raise NotImplementedError(f"Modality {modality} not implemented")
213216
else:
214217
raise ValueError(f"Unknown strategy {strategy}")
215218

test/test_iterators.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@
3131
KWARG_ERROR_MSG,
3232
UPTAKE_KWARG,
3333
_value_iterator,
34-
bvalue_iterator,
3534
centralsym_iterator,
3635
linear_iterator,
36+
monotonic_value_iterator,
3737
random_iterator,
38-
uptake_iterator,
3938
)
4039

4140

@@ -144,43 +143,46 @@ def test_centralsym_iterator(kwargs, expected):
144143
assert list(centralsym_iterator(**kwargs)) == expected
145144

146145

147-
def test_bvalue_iterator_error():
148-
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=BVALS_KWARG)):
149-
list(bvalue_iterator())
150-
151-
152146
@pytest.mark.parametrize(
153-
"bvals, expected",
147+
"kwargs",
154148
[
155-
([0, 700, 1200], [0, 1, 2]),
156-
([0, 0, 1000, 700], [0, 1, 3, 2]),
157-
([0, 1000, 1500, 700, 2000], [0, 3, 1, 2, 4]),
149+
{},
150+
{"bvals": None},
151+
{"uptake": None},
152+
{"bvals": None, "uptake": None},
158153
],
159154
)
160-
def test_bvalue_iterator(bvals, expected):
161-
obtained = list(bvalue_iterator(bvals=bvals))
162-
assert set(obtained) == set(range(len(bvals)))
163-
# Should be ordered by increasing bvalue
164-
sorted_bvals = [bvals[i] for i in obtained]
165-
assert sorted_bvals == sorted(bvals)
155+
def test_monotonic_value_iterator_error(kwargs):
156+
with pytest.raises(
157+
TypeError, match=KWARG_ERROR_MSG.format(kwarg=f"{BVALS_KWARG} or {UPTAKE_KWARG}")
158+
):
159+
monotonic_value_iterator(**kwargs)
160+
166161

162+
def test_monotonic_value_iterator_sorting_preference():
163+
result = list(monotonic_value_iterator(bvals=[700, 1000], uptake=[0.14, 0.23, 0.47]))
164+
assert result == [0, 1]
167165

168-
def test_uptake_iterator_error():
169-
with pytest.raises(TypeError, match=KWARG_ERROR_MSG.format(kwarg=UPTAKE_KWARG)):
170-
list(uptake_iterator())
166+
result = list(monotonic_value_iterator(bvals=None, uptake=[0.14, 0.23, 0.47]))
167+
assert result == [2, 1, 0]
171168

172169

173170
@pytest.mark.parametrize(
174-
"uptake, expected",
171+
"feature, values, expected",
175172
[
176-
([0.3, 0.2, 0.1], [0, 1, 2]),
177-
([0.2, 0.1, 0.3], [2, 1, 0]),
178-
([-1.02, 1.16, -0.56, 0.43], [1, 3, 2, 0]),
173+
("bvals", [0, 700, 1200], [0, 1, 2]),
174+
("bvals", [0, 0, 1000, 700], [0, 1, 3, 2]),
175+
("bvals", [0, 1000, 1500, 700, 2000], [0, 3, 1, 2, 4]),
176+
("uptake", [0.3, 0.2, 0.1], [0, 1, 2]),
177+
("uptake", [0.2, 0.1, 0.3], [2, 1, 0]),
178+
("uptake", [-1.02, 1.16, -0.56, 0.43], [1, 3, 2, 0]),
179179
],
180180
)
181-
def test_uptake_iterator_valid(uptake, expected):
182-
obtained = list(uptake_iterator(uptake=uptake))
183-
assert set(obtained) == set(range(len(uptake)))
184-
# Should be ordered by decreasing uptake
185-
sorted_uptake = [uptake[i] for i in obtained]
186-
assert sorted_uptake == sorted(uptake, reverse=True)
181+
def test_monotonic_value_iterator(feature, values, expected):
182+
obtained = list(monotonic_value_iterator(**{feature: values}))
183+
assert set(obtained) == set(range(len(values)))
184+
# If b-values, should be ordered by increasing value; if uptake values,
185+
# should be ordered by decreasing uptake
186+
sorted_vals = [values[i] for i in obtained]
187+
reverse = True if feature == "uptake" else False
188+
assert sorted_vals == sorted(values, reverse=reverse)

0 commit comments

Comments
 (0)