Skip to content

Commit 6b6a6c7

Browse files
committed
store 'strategy' attribute in patientratiosplit serialization
1 parent edb0c36 commit 6b6a6c7

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

kvasircapsuleloader/split.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def generate(
4242
:type seed: int, optional
4343
"""
4444
self._seed = seed
45+
self._strategy = strategy
4546
self.samples: Dict[str, List[KvasirCapsuleSample]] = {
4647
key: [] for key in self._ratios
4748
}
@@ -105,6 +106,7 @@ def load(path: Path, metadata: KvasirCapsuleMetadata) -> "PatientRatioSplit":
105106
**data["ratios"]
106107
)
107108
split._seed = data["seed"]
109+
split._strategy = data["strategy"]
108110
S = metadata.samples_by_filename()
109111
for phase in split._ratios:
110112
split.samples[phase] = [S[filename] for filename in data["samples"][phase]]
@@ -120,6 +122,7 @@ def save(self, path: Path):
120122
data = {
121123
"ratios": {**self._ratios},
122124
"seed": self._seed,
125+
"strategy": self._strategy,
123126
"samples": {
124127
phase: [sample.filename for sample in self.samples[phase]]
125128
for phase in self._ratios
@@ -130,6 +133,12 @@ def save(self, path: Path):
130133

131134

132135
def make_kfold_split(k: int):
136+
"""
137+
Factoy function that creates a split definition for k-fold cross-validation.
138+
139+
:param k: Number of folds, must be > 0
140+
:type k: int
141+
"""
133142
assert k > 0
134143
return PatientRatioSplit(
135144
**{ f"fold{i}": 1 / k for i in range(k) }

0 commit comments

Comments
 (0)