Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions aucmedi/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pathos.helpers import mp # instead of 'import multiprocessing as mp'
import numpy as np
import shutil
import pandas as pd
# Internal libraries
from aucmedi import DataGenerator, NeuralNetwork
from aucmedi.sampling import sampling_kfold
Expand Down Expand Up @@ -97,7 +98,7 @@ class Bagging:
An Analysis on Ensemble Learning optimized Medical Image Classification with Deep Convolutional Neural Networks.
arXiv e-print: [https://arxiv.org/abs/2201.11440](https://arxiv.org/abs/2201.11440)
"""
def __init__(self, model, k_fold=3):
def __init__(self, model, k_fold=3, seed=None):
""" Initialization function for creating a Bagging object.

Args:
Expand All @@ -108,6 +109,7 @@ def __init__(self, model, k_fold=3):
self.model_template = model
self.k_fold = k_fold
self.cache_dir = None
self.seed = seed

# Set multiprocessing method to spawn
mp.set_start_method("spawn", force=True)
Expand Down Expand Up @@ -149,7 +151,39 @@ def train(self, training_generator, epochs=20, iterations=None,

# Apply cross-validation sampling
cv_sampling = sampling_kfold(x, y, m, n_splits=self.k_fold,
stratified=True, iterative=True)
stratified=True, iterative=True, seed=self.seed)

# Save cross-validation sampling in list
records = []

# Sequentially iterate over all folds
for i, fold in enumerate(cv_sampling):
if len(fold) == 4:
train_x, train_y, val_x, val_y = fold
else:
train_x, train_y, _, val_x, val_y, _ = fold

# Training data
for sample, label in zip(train_x, train_y):
records.append({
'fold': i,
'subset': 'train',
'sample': sample,
'label': int(np.argmax(label))
})

# Validation data
for sample, label in zip(val_x, val_y):
records.append({
'fold': i,
'subset': 'val',
'sample': sample,
'label': int(np.argmax(label))
})

# Save as CSV
df = pd.DataFrame(records)
df.to_csv(os.path.join(self.cache_dir.name, "kfold_split.csv"), index=False)

# Sequentially iterate over all folds
for i, fold in enumerate(cv_sampling):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,19 @@ def test_Bagging_dump(self):
target = tempfile.TemporaryDirectory(prefix="tmp.aucmedi.",
suffix=".model")
self.assertTrue(len(os.listdir(target.name))==0)
self.assertTrue(len(os.listdir(el.cache_dir.name))==4)
self.assertTrue(len(os.listdir(el.cache_dir.name))==5)
origin = el.cache_dir.name
# Dump model
target_dir = os.path.join(target.name, "test")
el.dump(target_dir)
self.assertTrue(len(os.listdir(target_dir))==4)
self.assertTrue(len(os.listdir(target_dir))==5)
self.assertFalse(os.path.exists(origin))
target_two = tempfile.TemporaryDirectory(prefix="tmp.aucmedi.",
suffix=".model")
target_dir_two = os.path.join(target_two.name, "test")
el.dump(target_dir_two)
self.assertTrue(len(os.listdir(target_dir_two))==4)
self.assertTrue(len(os.listdir(target_dir))==4)
self.assertTrue(len(os.listdir(target_dir_two))==5)
self.assertTrue(len(os.listdir(target_dir))==5)
self.assertTrue(os.path.exists(target_dir))

def test_Bagging_load(self):
Expand Down
Loading