Skip to content

Commit edb0c36

Browse files
committed
create factory function for k-fold cross validation splits
1 parent 14410d9 commit edb0c36

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

kvasircapsuleloader/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class KvasirCapsuleMetadata:
1313
This is basically an abstraction for the records in metadata.csv.
1414
"""
1515

16-
def __init__(self):
16+
def __init__(self) -> None:
1717
self._data = pd.read_csv(KVASIR_CAPSULE_PATH / "metadata.csv", delimiter=";")
1818
self.video_ids = self._data.video_id
1919
self.samples: List[KvasirCapsuleSample] = []

kvasircapsuleloader/split.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,13 @@ class PatientRatioSplit:
1717
Generalized split by ratio that respects patient IDs (Experimental)
1818
"""
1919

20-
def __init__(self, **ratios):
    """
    Create a split description from named phase ratios.

    Each keyword names a phase (e.g. ``train=0.7, val=0.2, test=0.1``) and
    maps it to the fraction of patients that phase should receive. Keyword
    order matters: ``generate`` hands the first phase whatever patients
    remain after all other phases have been sized.

    :param ratios: Phase name -> ratio of patients assigned to that phase
    :raises ValueError: If ratios don't add up to 1.0
    """
    import math

    # Exact float equality rejects perfectly valid inputs, e.g.
    # 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999 and ten folds of
    # 1 / 10 may not sum to exactly 1.0, so compare with a tolerance.
    if not math.isclose(sum(ratios.values()), 1.0):
        raise ValueError("Split ratios need to sum up to 1.")
    self._ratios = ratios
3527

3628
def generate(
3729
self,
@@ -65,19 +57,22 @@ def generate(
6557
)
6658
logging.warning(f"Ignoring class {finding_class}.")
6759
continue
60+
61+
first_phase = list(self._ratios.keys())[0]
62+
N = {
63+
first_phase: N_patients
64+
}
6865
# make sure that sets represent at least one patient
69-
N = {}
70-
N["train"] = N_patients
7166
for phase, ratio in self._ratios.items():
72-
if phase == "train":
67+
if phase == first_phase:
7368
continue
7469
N[phase] = max(int(np.round(N_patients * ratio)), 1)
75-
N["train"] -= N[phase]
76-
assert N["train"] > 0
70+
N[first_phase] -= N[phase]
71+
assert N[first_phase] > 0
7772
idx = np.arange(N_patients)
7873
if strategy == "sort":
7974
# sort indices descending by number of samples
80-
# --> training set will contain most samples
75+
# --> first set will contain most samples
8176
idx = np.argsort([len(p) for p in patients])[::-1]
8277
else:
8378
np.random.shuffle(idx)
@@ -107,8 +102,7 @@ def load(path: Path, metadata: KvasirCapsuleMetadata) -> "PatientRatioSplit":
107102
with open(path, "r") as f:
108103
data = json.load(f)
109104
split = PatientRatioSplit(
110-
data["ratios"]["train"],
111-
**{k: v for k, v in data["ratios"].items() if k != "train"},
105+
**data["ratios"]
112106
)
113107
split._seed = data["seed"]
114108
S = metadata.samples_by_filename()
@@ -133,3 +127,10 @@ def save(self, path: Path):
133127
}
134128
with open(path, "w") as f:
135129
json.dump(data, f)
130+
131+
132+
def make_kfold_split(k: int) -> "PatientRatioSplit":
    """
    Build a ``PatientRatioSplit`` with ``k`` equally sized folds.

    The folds are named ``fold0`` ... ``fold{k-1}`` and each one is given
    the ratio ``1 / k``.

    :param k: Number of folds, must be at least 1
    :raises ValueError: If ``k`` is not positive
    :return: Split whose phases are the ``k`` folds
    """
    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O``, and a non-positive k would otherwise produce an empty
    # ratio mapping and a misleading "ratios need to sum up to 1" error.
    if k <= 0:
        raise ValueError(f"Number of folds must be positive, got {k}.")
    fold_ratio = 1 / k
    return PatientRatioSplit(**{f"fold{i}": fold_ratio for i in range(k)})

0 commit comments

Comments
 (0)