@@ -82,8 +82,7 @@ def info(self) -> str:
8282 Shows a markdown-formatted info string with training parameters.
8383 Useful for showing info on tensorboard to keep track of parameter changes.
8484
85- Returns:
86- str: Markdown-formatted info string.
85+ :returns [str]: Markdown-formatted info string.
8786 """
8887 s = "BasicNCATrainer Info\n "
8988 s += "-------------------\n "
@@ -102,6 +101,53 @@ def info(self) -> str:
102101 s += f"**{ attribute_f } :** { getattr (self , attribute )} \n "
103102 return s
104103
104+ def train_iteration (
105+ self ,
106+ x : torch .Tensor ,
107+ y : torch .Tensor ,
108+ steps : int ,
109+ optimizer : torch .optim .Optimizer ,
110+ scheduler : torch .optim .lr_scheduler .LRScheduler ,
111+ total_batch_iterations : int ,
112+ summary_writer ,
113+ ) -> torch .Tensor :
114+ """
115+ Run a single training iteration.
116+
117+ :param x [Tensor]: Input training images.
118+ :param y [Tensor]: Input training labels.
119+ :param steps [int]: Number of NCA inference time steps.
120+ :param optimizer [torch.optim.Optimizer]: Optimizer.
121+ :param scheduler [torch.optim.lr_scheduler.LRScheduler]: Scheduler.
122+ :param total_batch_iterations [int]: Total training batch iterations
123+
124+ :returns [Tensor]: Predicted image.
125+ """
126+ device = self .nca .device
127+ self .nca .train ()
128+ optimizer .zero_grad ()
129+ x_pred = x .clone ().to (self .nca .device )
130+ if self .truncate_backprop :
131+ for step in range (steps ):
132+ x_pred = self .nca (x_pred , steps = 1 )
133+ if step < steps - 10 :
134+ x_pred .detach ()
135+ else :
136+ x_pred = self .nca (x_pred , steps = steps )
137+ losses = self .nca .loss (x_pred , y .to (device ))
138+ losses ["total" ].backward ()
139+
140+ if self .gradient_clipping :
141+ torch .nn .utils .clip_grad_norm_ (self .nca .parameters (), 1.0 )
142+ optimizer .step ()
143+ scheduler .step ()
144+ if summary_writer :
145+ for key in losses :
146+ summary_writer .add_scalar (
147+ f"Loss/train_{ key } " , losses [key ], total_batch_iterations
148+ )
149+ return x_pred
150+
105151 def train (
106152 self ,
107153 dataloader_train : DataLoader ,
@@ -117,19 +163,14 @@ def train(
117163 """
118164 Execute basic NCA training loop with a single function call.
119165
120- Args:
121- dataloader_train (DataLoader): Training DataLoader
122- dataloader_val (DataLoader): Validation DataLoader
123- save_every (int, optional):
124- How often to save model state (in epochs). Useful for very small datasets, like growing lizard.
125- summary_writer (SummaryWriter, optional):
126- Tensorboard SummaryWriter. Defaults to None.
127- plot_function (Callable[ [np.ndarray, np.ndarray, np.ndarray, BasicNCAModel], Figure ], optional):
128- Plot function override. If None, use model's default. Defaults to None.
129- earlystopping (EarlyStopping, optional): EarlyStopping object. Defaults to None.
130-
131- Returns:
132- TrainingSummary: TrainingSummary object.
166+ :param dataloader_train [DataLoader]: Training DataLoader
167+ :param dataloader_val [DataLoader]: Validation DataLoader
168+ :param save_every [int]: How often to save model state (in epochs). Useful for very small datasets, like growing lizard.
:param summary_writer [SummaryWriter]: Tensorboard SummaryWriter. Defaults to None.
170+ :param plot_function: Plot function override. If None, use model's default. Defaults to None.
:param earlystopping [EarlyStopping]: EarlyStopping object. Defaults to None.
172+
173+ :returns [TrainingSummary]: TrainingSummary object.
133174 """
134175 logging .basicConfig (encoding = "utf-8" , level = logging .INFO )
135176
@@ -154,53 +195,6 @@ def train(
154195 else :
155196 best_path = None
156197
157- def train_iteration (
158- x : torch .Tensor ,
159- y : torch .Tensor ,
160- steps : int ,
161- optimizer : torch .optim .Optimizer ,
162- scheduler : torch .optim .lr_scheduler .LRScheduler ,
163- total_batch_iterations : int ,
164- ) -> torch .Tensor :
165- """
166- Run a single training iteration.
167-
168- Args:
169- x (torch.Tensor): Input training images.
170- y (torch.Tensor): Input training labels.
171- steps (int): Number of NCA inference time steps.
172- optimizer (torch.optim.Optimizer): Optimizer.
173- scheduler (torch.optim.lr_scheduler.LRScheduler): Scheduler.
174- total_batch_iterations (int): Total training batch iterations
175-
176- Returns:
177- torch.Tensor: Predicted image.
178- """
179- device = self .nca .device
180- self .nca .train ()
181- optimizer .zero_grad ()
182- x_pred = x .clone ().to (self .nca .device )
183- if self .truncate_backprop :
184- for step in range (steps ):
185- x_pred = self .nca (x_pred , steps = 1 )
186- if step < steps - 10 :
187- x_pred .detach ()
188- else :
189- x_pred = self .nca (x_pred , steps = steps )
190- losses = self .nca .loss (x_pred , y .to (device ))
191- losses ["total" ].backward ()
192-
193- if self .gradient_clipping :
194- torch .nn .utils .clip_grad_norm_ (self .nca .parameters (), 1.0 )
195- optimizer .step ()
196- scheduler .step ()
197- if summary_writer :
198- for key in losses :
199- summary_writer .add_scalar (
200- f"Loss/train_{ key } " , losses [key ], total_batch_iterations
201- )
202- return x_pred
203-
204198 # MAIN LOOP
205199 total_batch_iterations = 0
206200 for iteration in tqdm (range (self .max_epochs ), desc = "Epochs" ):
@@ -251,16 +245,26 @@ def train_iteration(
251245 y = torch .cat (self .batch_repeat * [y ])
252246
253247 steps = np .random .randint (* self .steps_range )
254- x_pred = train_iteration (
255- x , y , steps , optimizer , scheduler , total_batch_iterations
248+ x_pred = self .train_iteration (
249+ x ,
250+ y ,
251+ steps ,
252+ optimizer ,
253+ scheduler ,
254+ total_batch_iterations ,
255+ summary_writer ,
256256 )
257257 if self .p_retain_pool > 0.0 :
258258 x_previous = x_pred
259259 total_batch_iterations += 1
260260
261261 with torch .no_grad ():
262262 # VISUALIZATION
263- if plot_function and summary_writer and (iteration + 1 ) % save_every == 0 :
263+ if (
264+ plot_function
265+ and summary_writer
266+ and (iteration + 1 ) % save_every == 0
267+ ):
264268 figure = plot_function (
265269 x .detach ().cpu ().numpy (),
266270 x_pred .detach ().cpu ().numpy (),
0 commit comments