@@ -42,6 +42,7 @@ def generate(
4242 :type seed: int, optional
4343 """
4444 self ._seed = seed
45+ self ._strategy = strategy
4546 self .samples : Dict [str , List [KvasirCapsuleSample ]] = {
4647 key : [] for key in self ._ratios
4748 }
@@ -105,6 +106,7 @@ def load(path: Path, metadata: KvasirCapsuleMetadata) -> "PatientRatioSplit":
105106 ** data ["ratios" ]
106107 )
107108 split ._seed = data ["seed" ]
109+ split ._strategy = data ["strategy" ]
108110 S = metadata .samples_by_filename ()
109111 for phase in split ._ratios :
110112 split .samples [phase ] = [S [filename ] for filename in data ["samples" ][phase ]]
@@ -120,6 +122,7 @@ def save(self, path: Path):
120122 data = {
121123 "ratios" : {** self ._ratios },
122124 "seed" : self ._seed ,
125+ "strategy" : self ._strategy ,
123126 "samples" : {
124127 phase : [sample .filename for sample in self .samples [phase ]]
125128 for phase in self ._ratios
@@ -130,6 +133,12 @@ def save(self, path: Path):
130133
131134
132135def make_kfold_split (k : int ):
136+ """
137+ Factoy function that creates a split definition for k-fold cross-validation.
138+
139+ :param k: Number of folds, must be > 0
140+ :type k: int
141+ """
133142 assert k > 0
134143 return PatientRatioSplit (
135144 ** { f"fold{ i } " : 1 / k for i in range (k ) }
0 commit comments