|
22 | 22 | # |
23 | 23 | """Data splitting helpers.""" |
24 | 24 |
|
25 | | -from pathlib import Path |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +from typing import Any |
26 | 28 |
|
27 | | -import h5py |
28 | 29 | import numpy as np |
29 | 30 |
|
| 31 | +from nifreeze.data.base import BaseDataset |
| 32 | + |
30 | 33 |
|
def lovo_split(dataset: BaseDataset, index: int) -> tuple[Any, Any]:
    """
    Produce one fold of LOVO (leave-one-volume-out).

    Parameters
    ----------
    dataset : :obj:`nifreeze.data.base.BaseDataset`
        Dataset object.
    index : :obj:`int`
        Index of the volume to be left out in this fold.

    Returns
    -------
    :obj:`tuple` of :obj:`tuple`
        A tuple of two elements, the first element being the components
        of the *train* data (including the data themselves and other metadata
        such as gradients for dMRI, or frame times for PET), and the second
        element being the *test* data.

    """
    n_volumes = len(dataset)

    # Boolean selector flagging only the held-out volume; negative indices
    # wrap around, following the usual numpy convention.
    held_out = np.full(n_volumes, False)
    held_out[index] = True

    train = dataset[~held_out]
    test = dataset[held_out]
    return train, test
0 commit comments