@@ -17,21 +17,13 @@ class PatientRatioSplit:
1717 Generalized split by ratio that respects patient IDs (Experimental)
1818 """
1919
20- def __init__ (
21- self ,
22- train : float ,
23- ** ratios ,
24- ):
20+ def __init__ (self , ** ratios ):
2521 """
26- :param train: Ratio for training set (other ratios are optional)
27- :type train: float
2822 :raises ValueError: If ratios don't add up to 1.0
2923 """
30- if sum (list (ratios .values ())) + train != 1.0 :
24+ if sum (list (ratios .values ())) != 1.0 :
3125 raise ValueError ("Split ratios need to sum up to 1." )
32- self .ratio_train = train
3326 self ._ratios = ratios
34- self ._ratios .update (train = train )
3527
3628 def generate (
3729 self ,
@@ -65,19 +57,22 @@ def generate(
6557 )
6658 logging .warning (f"Ignoring class { finding_class } ." )
6759 continue
60+
61+ first_phase = list (self ._ratios .keys ())[0 ]
62+ N = {
63+ first_phase : N_patients
64+ }
6865 # make sure that sets represent at least one patient
69- N = {}
70- N ["train" ] = N_patients
7166 for phase , ratio in self ._ratios .items ():
72- if phase == "train" :
67+ if phase == first_phase :
7368 continue
7469 N [phase ] = max (int (np .round (N_patients * ratio )), 1 )
75- N ["train" ] -= N [phase ]
76- assert N ["train" ] > 0
70+ N [first_phase ] -= N [phase ]
71+ assert N [first_phase ] > 0
7772 idx = np .arange (N_patients )
7873 if strategy == "sort" :
7974 # sort indices descending by number of samples
80- # --> training set will contain most samples
75+ # --> first set will contain most samples
8176 idx = np .argsort ([len (p ) for p in patients ])[::- 1 ]
8277 else :
8378 np .random .shuffle (idx )
@@ -107,8 +102,7 @@ def load(path: Path, metadata: KvasirCapsuleMetadata) -> "PatientRatioSplit":
107102 with open (path , "r" ) as f :
108103 data = json .load (f )
109104 split = PatientRatioSplit (
110- data ["ratios" ]["train" ],
111- ** {k : v for k , v in data ["ratios" ].items () if k != "train" },
105+ ** data ["ratios" ]
112106 )
113107 split ._seed = data ["seed" ]
114108 S = metadata .samples_by_filename ()
@@ -133,3 +127,10 @@ def save(self, path: Path):
133127 }
134128 with open (path , "w" ) as f :
135129 json .dump (data , f )
130+
131+
132+ def make_kfold_split (k : int ):
133+ assert k > 0
134+ return PatientRatioSplit (
135+ ** { f"fold{ i } " : 1 / k for i in range (k ) }
136+ )
0 commit comments