Skip to content

Commit 493e5e3

Browse files
author
rizoudal
committed
refactor: update DataGenerator and architectures for resize=(64, 64, 64)
- refactored architectures (DenseNet121, DenseNet169, DenseNet201, MobileNetV2, ResNeXt50, ResNeXt101) - updated input_shape of each NeuralNetwork to (64, 64, 64)
1 parent 560f1cb commit 493e5e3

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

tests/test_architectures_volume.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def setUpClass(self):
6666
loader=numpy_loader, two_dim=False,
6767
grayscale=True, batch_size=1)
6868

69+
self.datagen_HU_64 = DataGenerator(self.sampleList_hu,
70+
self.tmp_data.name,
71+
labels=self.labels_ohe,
72+
resize=(64, 64, 64),
73+
loader=numpy_loader, two_dim=False,
74+
grayscale=True, batch_size=1)
75+
6976
#-------------------------------------------------#
7077
# Architecture: Vanilla #
7178
#-------------------------------------------------#
@@ -86,11 +93,11 @@ def test_Vanilla(self):
8693
#-------------------------------------------------#
8794
def test_DenseNet121(self):
8895
arch = DenseNet121(Classifier(n_labels=4), channels=1,
89-
input_shape=(32, 32, 32))
96+
input_shape=(64, 64, 64))
9097
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch)
91-
model.predict(self.datagen_HU)
98+
model.predict(self.datagen_HU_64)
9299
model = NeuralNetwork(n_labels=4, channels=3, architecture="3D.DenseNet121",
93-
input_shape=(32, 32, 32))
100+
input_shape=(64, 64, 64))
94101
try : model.model.summary()
95102
except : raise Exception()
96103
self.assertTrue(supported_standardize_mode["DenseNet121"] == "torch")
@@ -101,11 +108,11 @@ def test_DenseNet121(self):
101108
#-------------------------------------------------#
102109
def test_DenseNet169(self):
103110
arch = DenseNet169(Classifier(n_labels=4), channels=1,
104-
input_shape=(32, 32, 32))
111+
input_shape=(64, 64, 64))
105112
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch)
106-
model.predict(self.datagen_HU)
113+
model.predict(self.datagen_HU_64)
107114
model = NeuralNetwork(n_labels=4, channels=3, architecture="3D.DenseNet169",
108-
input_shape=(32, 32, 32))
115+
input_shape=(64, 64, 64))
109116
try : model.model.summary()
110117
except : raise Exception()
111118
self.assertTrue(supported_standardize_mode["DenseNet169"] == "torch")
@@ -116,11 +123,11 @@ def test_DenseNet169(self):
116123
#-------------------------------------------------#
117124
def test_DenseNet201(self):
118125
arch = DenseNet201(Classifier(n_labels=4), channels=1,
119-
input_shape=(32, 32, 32))
126+
input_shape=(64, 64, 64))
120127
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch)
121-
model.predict(self.datagen_HU)
128+
model.predict(self.datagen_HU_64)
122129
model = NeuralNetwork(n_labels=4, channels=3, architecture="3D.DenseNet201",
123-
input_shape=(32, 32, 32))
130+
input_shape=(64, 64, 64))
124131
try : model.model.summary()
125132
except : raise Exception()
126133
self.assertTrue(supported_standardize_mode["DenseNet201"] == "torch")
@@ -206,11 +213,11 @@ def test_ResNet152(self):
206213
#-------------------------------------------------#
207214
def test_ResNeXt50(self):
208215
arch = ResNeXt50(Classifier(n_labels=4), channels=1,
209-
input_shape=(32, 32, 32))
216+
input_shape=(64, 64, 64))
210217
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch)
211-
model.predict(self.datagen_HU)
218+
model.predict(self.datagen_HU_64)
212219
model = NeuralNetwork(n_labels=4, channels=3, architecture="3D.ResNeXt50",
213-
input_shape=(32, 32, 32))
220+
input_shape=(64, 64, 64))
214221
try : model.model.summary()
215222
except : raise Exception()
216223
self.assertTrue(supported_standardize_mode["ResNeXt50"] == "grayscale")
@@ -221,11 +228,11 @@ def test_ResNeXt50(self):
221228
#-------------------------------------------------#
222229
def test_ResNeXt101(self):
223230
arch = ResNeXt101(Classifier(n_labels=4), channels=1,
224-
input_shape=(32, 32, 32))
231+
input_shape=(64, 64, 64))
225232
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch)
226-
model.predict(self.datagen_HU)
233+
model.predict(self.datagen_HU_64)
227234
model = NeuralNetwork(n_labels=4, channels=3, architecture="3D.ResNeXt101",
228-
input_shape=(32, 32, 32))
235+
input_shape=(64, 64, 64))
229236
try : model.model.summary()
230237
except : raise Exception()
231238
self.assertTrue(supported_standardize_mode["ResNeXt101"] == "grayscale")
@@ -283,7 +290,7 @@ def test_MobileNetV2(self):
283290
arch = MobileNetV2(Classifier(n_labels=4), channels=1,
284291
input_shape=(64, 64, 64))
285292
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch)
286-
model.predict(self.datagen_HU)
293+
model.predict(self.datagen_HU_64)
287294
model = NeuralNetwork(n_labels=4, channels=3, architecture="3D.MobileNetV2",
288295
input_shape=(64, 64, 64))
289296
try : model.model.summary()

0 commit comments

Comments
 (0)