@@ -36,6 +36,8 @@ def train_class_pathmnist(
3636 hidden_channels : int ,
3737 gpu : bool ,
3838 gpu_index : int ,
39+ max_epochs : int ,
40+ dry : bool ,
3941):
4042 comment = "PathMNIST"
4143 comment += f"_hidden_{ hidden_channels } "
@@ -44,7 +46,7 @@ def train_class_pathmnist(
4446 comment += f"_AM_{ alive_mask } "
4547 comment += f"_TE_{ use_temporal_encoding } "
4648
47- writer = SummaryWriter (comment = comment )
49+ writer = SummaryWriter (comment = comment ) if not dry else None
4850
4951 device = get_compute_device (f"cuda:{ gpu_index } " if gpu else "cpu" )
5052
@@ -80,22 +82,23 @@ def train_class_pathmnist(
8082 pad_noise = pad_noise ,
8183 use_temporal_encoding = use_temporal_encoding ,
8284 class_names = list (INFO ["pathmnist" ]["label" ].values ()),
85+ training_timesteps = 32 ,
86+ inference_timesteps = 32 ,
8387 )
8488 trainer = BasicNCATrainer (
8589 nca ,
86- WEIGHTS_PATH / "classification_pathmnist" ,
90+ WEIGHTS_PATH / "classification_pathmnist" if not dry else None ,
8791 batch_repeat = 2 ,
88- max_epochs = 10 ,
92+ max_epochs = max_epochs ,
8993 gradient_clipping = gradient_clipping ,
90- steps_range = (32 , 33 ),
91- steps_validation = 32 ,
9294 )
9395 trainer .train (
9496 loader_train ,
9597 loader_val ,
9698 summary_writer = writer ,
9799 )
98- writer .close ()
100+ if writer is not None :
101+ writer .close ()
99102
100103
101104@click .command ()
@@ -107,12 +110,18 @@ def train_class_pathmnist(
107110@click .option (
108111 "--gpu-index" , type = int , default = 0 , help = "Index of GPU to use, if --gpu in use."
109112)
110- def main (batch_size , hidden_channels , gpu : bool , gpu_index : int ):
113+ @click .option ("--max-epochs" , "-E" , type = int , default = 10 )
114+ @click .option ("--dry" , "-D" , is_flag = True )
115+ def main (
116+ batch_size , hidden_channels , gpu : bool , gpu_index : int , max_epochs : int , dry : bool
117+ ):
111118 train_class_pathmnist (
112119 batch_size = batch_size ,
113120 hidden_channels = hidden_channels ,
114121 gpu = gpu ,
115122 gpu_index = gpu_index ,
123+ max_epochs = max_epochs ,
124+ dry = dry ,
116125 )
117126
118127
0 commit comments