Skip to content

Commit c656db9

Browse files
authored
Merge branch 'main' into enh/data-refactor
2 parents e533ff5 + fdab87b commit c656db9

File tree

4 files changed

+11
-11
lines changed

4 files changed

+11
-11
lines changed

src/nifreeze/model/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
TrivialModel,
2929
)
3030
from nifreeze.model.dmri import (
31-
AverageDWModel,
31+
AverageDWIModel,
3232
DKIModel,
3333
DTIModel,
3434
GPModel,
@@ -38,7 +38,7 @@
3838
__all__ = (
3939
"ModelFactory",
4040
"AverageModel",
41-
"AverageDWModel",
41+
"AverageDWIModel",
4242
"DKIModel",
4343
"DTIModel",
4444
"GPModel",

src/nifreeze/model/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def init(model="DTI", **kwargs):
5353
return TrivialModel(predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab"))
5454

5555
if model.lower() in ("avgdwi", "averagedwi", "meandwi"):
56-
from nifreeze.model.dmri import AverageDWModel
56+
from nifreeze.model.dmri import AverageDWIModel
5757

58-
return AverageDWModel(**kwargs)
58+
return AverageDWIModel(**kwargs)
5959

6060
if model.lower() in ("avg", "average", "mean"):
6161
return AverageModel(**kwargs)

src/nifreeze/model/dmri.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,8 @@ def predict(self, gradient=None, **kwargs):
218218
return retval
219219

220220

221-
class AverageDWModel(BaseDWIModel):
222-
"""A trivial model that returns an average map."""
221+
class AverageDWIModel(BaseDWIModel):
222+
"""A trivial model that returns an average DWI volume."""
223223

224224
__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted")
225225

test/test_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def test_average_model():
7878

7979
data *= gtab[:, -1]
8080

81-
tmodel_mean = model.AverageDWModel(gtab=gtab, bias=False, stat="mean")
82-
tmodel_median = model.AverageDWModel(gtab=gtab, bias=False, stat="median")
83-
tmodel_1000 = model.AverageDWModel(gtab=gtab, bias=False, th_high=1000, th_low=900)
84-
tmodel_2000 = model.AverageDWModel(
81+
tmodel_mean = model.AverageDWIModel(gtab=gtab, bias=False, stat="mean")
82+
tmodel_median = model.AverageDWIModel(gtab=gtab, bias=False, stat="median")
83+
tmodel_1000 = model.AverageDWIModel(gtab=gtab, bias=False, th_high=1000, th_low=900)
84+
tmodel_2000 = model.AverageDWIModel(
8585
gtab=gtab,
8686
bias=False,
8787
th_high=2000,
@@ -153,7 +153,7 @@ def test_two_initialisations(datadir):
153153
data_train, data_test = lovo_split(dmri_dataset, 10)
154154

155155
# Direct initialisation
156-
model1 = model.AverageDWModel(
156+
model1 = model.AverageDWIModel(
157157
gtab=data_train[1],
158158
S0=dmri_dataset.bzero,
159159
th_low=100,

0 commit comments

Comments
 (0)