1111
1212
1313class TrainValRecord :
14+ """
15+ Helper class, storing a training / validation data split to generate
16+ respective DataLoader objects.
17+ """
18+
1419 def __init__ (
1520 self ,
1621 train : List [str ],
1722 val : List [str ],
1823 ):
24+ """
25+ Constructor.
26+
27+ :param train (List[str]): List of training image file paths
28+ :param val (List[str]): List of validation image file paths
29+ """
1930 self .train = train
2031 self .val = val
2132
@@ -26,6 +37,10 @@ def dataloaders(
2637 transform = None ,
2738 batch_sizes = None ,
2839 ):
40+ """
41+ Generate a pair of training and validation DataLoader objects, based on
42+ a given DataSet subtype.
43+ """
2944 if batch_sizes is None :
3045 batch_sizes = {"train" : 8 , "val" : 8 }
3146 dataset_train = DatasetType (path , self .train , transform )
@@ -45,7 +60,14 @@ def dataloaders(
4560
4661
4762class SplitDefinition :
63+ """
64+ Stores a k-fold cross-validation split.
65+ """
66+
4867 def __init__ (self ):
68+ """
69+ Constructor.
70+ """
4971 self .folds : List [None | TrainValRecord ] = []
5072 self .dataloader_test = None
5173
@@ -55,18 +77,21 @@ def read(path: PosixPath):
5577 Reads json files with split definitions, similar to those created by nnUNet.
5678
5779 Format is like
58- [
59- {
60- "train": [ "filename0", "filename1",... ]
61- "val": [ "filename2", "filename3",... ]
62- },
63- {
64- ...
65- }
66- ]
67-
68- Args:
69- path (PosixPath): Path to JSON file containing split definition.
80+
81+ .. highlight:: python
82+ .. code-block:: python
83+
84+ [
85+ {
86+ "train": [ "filename0", "filename1",... ]
87+ "val": [ "filename2", "filename3",... ]
88+ },
89+ {
90+ ...
91+ }
92+ ]
93+
94+ :param path [PosixPath]: Path to JSON file containing split definition.
7095 """
7196 with open (path , "r" ) as f :
7297 d = json .load (f )
@@ -89,9 +114,15 @@ def __getitem__(self, idx) -> TrainValRecord:
89114
90115class KFoldCrossValidationTrainer :
91116 def __init__ (self , trainer : BasicNCATrainer , split : SplitDefinition ):
117+ """
118+ Constructor.
119+
120+ :param trainer [BasicNCATrainer]: BasicNCATrainer, to train each individual fold.
121+ :param split [SplitDefinition]: Definition of the split used for k-fold cross-training.
122+ """
92123 self .trainer = trainer
93124 self .model_prototype = copy .deepcopy (trainer .nca )
94- self .model_name = trainer .model_path .with_suffix ('' )
125+ self .model_name = trainer .model_path .with_suffix ("" )
95126 self .split = split
96127
97128 def train (
@@ -102,6 +133,18 @@ def train(
102133 batch_sizes : None | Dict = None ,
103134 save_every : int | None = None ,
104135 ) -> List [TrainingSummary ]:
136+ """
137+ Run training loop with a single function call.
138+
139+ :param DatasetType [Type]: Type of dataset class to use.
140+ :param datapath [Path]: _description_
141+ :param transform: Data transform, e.g. initialized via Albumentations.
142+ :param batch_sizes: Dict of batch sizes per set, e.g. {"train": 8, "val": 16}. Defaults to None.
143+ :param save_every [int]: _description_. Defaults to None.
144+ :param plot_function: Plot function override. If None, use model's default. Defaults to None.
145+
146+ :returns [List[TrainingSummary]]: List of TrainingSummary objects, one per fold.
147+ """
105148 k = len (self .split )
106149 summaries = []
107150 for i in range (k ):
0 commit comments