Skip to content

Commit 052aa0e

Browse files
committed
fix: calls to lovo_split
1 parent ae556f8 commit 052aa0e

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

src/nifreeze/estimator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,13 @@ def estimate(
141141
pbar.set_description_str(
142142
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
143143
)
144-
data_train, data_test = lovo_split(data, i, with_b0=True)
145-
grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
144+
data_train, data_test = lovo_split(data, i)
145+
grad_str = f"{i}, {data_test[-1][:3]}, b={int(data_test[-1][3])}"
146146
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")
147147

148148
if not single_model: # A true LOGO estimator
149149
if hasattr(data, "gradients"):
150-
kwargs["gtab"] = data_train[1]
150+
kwargs["gtab"] = data_train[-1]
151151
# Factory creates the appropriate model and pipes arguments
152152
dwmodel = ModelFactory.init(
153153
model=model,
@@ -162,7 +162,7 @@ def estimate(
162162
)
163163

164164
# generate a synthetic dw volume for the test gradient
165-
predicted = dwmodel.predict(data_test[1])
165+
predicted = dwmodel.predict(data_test[-1])
166166

167167
# prepare data for running ANTs
168168
fixed, moving = _prepare_registration_data(
@@ -180,7 +180,7 @@ def estimate(
180180
data.motion_affines,
181181
data.affine,
182182
data.dataobj.shape[:3],
183-
data_test[1][3],
183+
data_test[-1][3],
184184
i_iter,
185185
i,
186186
ptmp_dir,

test/test_model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,19 @@ def test_two_initialisations(datadir):
154154

155155
# Direct initialisation
156156
model1 = model.AverageDWIModel(
157-
gtab=data_train[1],
157+
gtab=data_train[-1],
158158
S0=dmri_dataset.bzero,
159159
th_low=100,
160160
th_high=1000,
161161
bias=False,
162162
stat="mean",
163163
)
164-
model1.fit(data_train[0], gtab=data_train[1])
165-
predicted1 = model1.predict(data_test[1])
164+
model1.fit(data_train[0], gtab=data_train[-1])
165+
predicted1 = model1.predict(data_test[-1])
166166

167167
# Initialisation via ModelFactory
168168
model2 = model.ModelFactory.init(
169-
gtab=data_train[1],
169+
gtab=data_train[-1],
170170
model="avgdwi",
171171
S0=dmri_dataset.bzero,
172172
th_low=100,
@@ -176,9 +176,9 @@ def test_two_initialisations(datadir):
176176
)
177177

178178
with pytest.raises(ModelNotFittedError):
179-
model2.predict(data_test[1])
179+
model2.predict(data_test[-1])
180180

181-
model2.fit(data_train[0], gtab=data_train[1])
182-
predicted2 = model2.predict(data_test[1])
181+
model2.fit(data_train[0], gtab=data_train[-1])
182+
predicted2 = model2.predict(data_test[-1])
183183

184184
assert np.all(predicted1 == predicted2)

test/test_splitting.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def test_lovo_split(datadir):
3737
3838
Returns:
3939
None
40+
4041
"""
4142
data = DWI.from_filename(datadir / "dwi.h5")
4243

@@ -52,11 +53,11 @@ def test_lovo_split(datadir):
5253
data.gradients[..., index] = 1
5354

5455
# Apply the lovo_split function at the specified index
55-
(train_data, train_gradients), (test_data, test_gradients) = lovo_split(data, index)
56+
train_data, test_data = lovo_split(data, index)
5657

5758
# Check if the test data contains only 1s
5859
# and the train data contains only 0s after the split
59-
assert np.all(test_data == 1)
60-
assert np.all(train_data == 0)
61-
assert np.all(test_gradients == 1)
62-
assert np.all(train_gradients == 0)
60+
assert np.all(test_data[0] == 1)
61+
assert np.all(train_data[0] == 0)
62+
assert np.all(test_data[-1] == 1)
63+
assert np.all(train_data[-1] == 0)

0 commit comments

Comments
 (0)