Skip to content

Commit ae556f8

Browse files
committed
enh: update LOVO splitter to new dataset indexed access
Leverage the new ``__getitem__`` interface, which provides uniform indexed access across modalities. Requires: #52. Related-To: #19. Resolves: #20.
1 parent 0dadde7 commit ae556f8

File tree

1 file changed

+16
-45
lines changed

1 file changed

+16
-45
lines changed

src/nifreeze/data/splitting.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -22,65 +22,36 @@
2222
#
2323
"""Data splitting helpers."""
2424

25-
from pathlib import Path
25+
from __future__ import annotations
26+
27+
from typing import Any
2628

27-
import h5py
2829
import numpy as np
2930

31+
from nifreeze.data.base import BaseDataset
32+
3033

31-
def lovo_split(dataset, index, with_b0=False):
34+
def lovo_split(dataset: BaseDataset, index: int) -> tuple[Any, Any]:
3235
"""
3336
Produce one fold of LOVO (leave-one-volume-out).
3437
3538
Parameters
3639
----------
37-
dataset : :obj:`nifreeze.data.dmri.DWI`
38-
DWI object
40+
dataset : :obj:`nifreeze.data.base.BaseDataset`
41+
Dataset object.
3942
index : :obj:`int`
40-
Index of the DWI orientation to be left out in this fold.
43+
Index of the volume to be left out in this fold.
4144
4245
Returns
4346
-------
44-
(train_data, train_gradients) : :obj:`tuple`
45-
Training DWI and corresponding gradients.
46-
Training data/gradients come **from the updated dataset**.
47-
(test_data, test_gradients) :obj:`tuple`
48-
Test 3D map (one DWI orientation) and corresponding b-vector/value.
49-
The test data/gradient come **from the original dataset**.
47+
:obj:`tuple` of :obj:`tuple`
48+
A tuple of two elements, the first element being the components
49+
of the *train* data (including the data themselves and other metadata
50+
such as gradients for dMRI, or frame times for PET), and the second
51+
element being the *test* data.
5052
5153
"""
52-
53-
if not Path(dataset.get_filename()).exists():
54-
dataset.to_filename(dataset.get_filename())
55-
56-
# read original DWI data & b-vector
57-
with h5py.File(dataset.get_filename(), "r") as in_file:
58-
root = in_file["/0"]
59-
data = np.asanyarray(root["dataobj"])
60-
gradients = np.asanyarray(root["gradients"])
61-
62-
# if the size of the mask does not match data, cache is stale
63-
mask = np.zeros(data.shape[-1], dtype=bool)
54+
mask = np.zeros(len(dataset), dtype=bool)
6455
mask[index] = True
6556

66-
train_data = data[..., ~mask]
67-
train_gradients = gradients[..., ~mask]
68-
test_data = data[..., mask]
69-
test_gradients = gradients[..., mask]
70-
71-
if with_b0:
72-
train_data = np.concatenate(
73-
(np.asanyarray(dataset.bzero)[..., np.newaxis], train_data),
74-
axis=-1,
75-
)
76-
b0vec = np.zeros((4, 1))
77-
b0vec[0, 0] = 1
78-
train_gradients = np.concatenate(
79-
(b0vec, train_gradients),
80-
axis=-1,
81-
)
82-
83-
return (
84-
(train_data, train_gradients),
85-
(test_data, test_gradients),
86-
)
57+
return (dataset[~mask], dataset[mask])

0 commit comments

Comments
 (0)