Skip to content

Commit 3fe650d

Browse files
committed
move train_iteration out of training loop to reduce complexity
1 parent 5810fff commit 3fe650d

File tree

1 file changed

+69
-65
lines changed

1 file changed

+69
-65
lines changed

ncalab/training/trainer.py

Lines changed: 69 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def info(self) -> str:
8282
Shows a markdown-formatted info string with training parameters.
8383
Useful for showing info on tensorboard to keep track of parameter changes.
8484
85-
Returns:
86-
str: Markdown-formatted info string.
85+
:returns [str]: Markdown-formatted info string.
8786
"""
8887
s = "BasicNCATrainer Info\n"
8988
s += "-------------------\n"
@@ -102,6 +101,53 @@ def info(self) -> str:
102101
s += f"**{attribute_f}:** {getattr(self, attribute)}\n"
103102
return s
104103

104+
def train_iteration(
    self,
    x: torch.Tensor,
    y: torch.Tensor,
    steps: int,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    total_batch_iterations: int,
    summary_writer,
) -> torch.Tensor:
    """
    Run a single training iteration.

    Puts the model in training mode, runs the NCA forward for ``steps``
    time steps, backpropagates ``losses["total"]``, optionally clips the
    gradient norm, then steps the optimizer and the scheduler.

    When ``self.truncate_backprop`` is set, the model is advanced one step
    at a time and the state is detached from the autograd graph after every
    step except the last 10, so gradients only flow through the final
    (at most) 10 steps.

    :param x [Tensor]: Input training images.
    :param y [Tensor]: Input training labels.
    :param steps [int]: Number of NCA inference time steps.
    :param optimizer [torch.optim.Optimizer]: Optimizer.
    :param scheduler [torch.optim.lr_scheduler.LRScheduler]: Scheduler.
    :param total_batch_iterations [int]: Total training batch iterations,
        used as the global step when logging losses.
    :param summary_writer: Tensorboard SummaryWriter. If falsy, losses are
        not logged.

    :returns [Tensor]: Predicted image.
    """
    device = self.nca.device
    self.nca.train()
    optimizer.zero_grad()
    x_pred = x.clone().to(device)
    if self.truncate_backprop:
        for step in range(steps):
            x_pred = self.nca(x_pred, steps=1)
            if step < steps - 10:
                # Tensor.detach() is NOT in-place: it returns a new tensor
                # that is cut off from the graph. The result has to be
                # re-bound, otherwise backprop is never truncated.
                x_pred = x_pred.detach()
    else:
        x_pred = self.nca(x_pred, steps=steps)
    losses = self.nca.loss(x_pred, y.to(device))
    losses["total"].backward()

    if self.gradient_clipping:
        torch.nn.utils.clip_grad_norm_(self.nca.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    if summary_writer:
        # Log every loss component returned by the model, not just the total.
        for key, value in losses.items():
            summary_writer.add_scalar(
                f"Loss/train_{key}", value, total_batch_iterations
            )
    return x_pred
150+
105151
def train(
106152
self,
107153
dataloader_train: DataLoader,
@@ -117,19 +163,14 @@ def train(
117163
"""
118164
Execute basic NCA training loop with a single function call.
119165
120-
Args:
121-
dataloader_train (DataLoader): Training DataLoader
122-
dataloader_val (DataLoader): Validation DataLoader
123-
save_every (int, optional):
124-
How often to save model state (in epochs). Useful for very small datasets, like growing lizard.
125-
summary_writer (SummaryWriter, optional):
126-
Tensorboard SummaryWriter. Defaults to None.
127-
plot_function (Callable[ [np.ndarray, np.ndarray, np.ndarray, BasicNCAModel], Figure ], optional):
128-
Plot function override. If None, use model's default. Defaults to None.
129-
earlystopping (EarlyStopping, optional): EarlyStopping object. Defaults to None.
130-
131-
Returns:
132-
TrainingSummary: TrainingSummary object.
166+
:param dataloader_train [DataLoader]: Training DataLoader
167+
:param dataloader_val [DataLoader]: Validation DataLoader
168+
:param save_every [int]: How often to save model state (in epochs). Useful for very small datasets, like growing lizard.
169+
:param summary_writer [SummaryWriter]: Tensorboard SummaryWriter. Defaults to None.
170+
:param plot_function: Plot function override. If None, use model's default. Defaults to None.
171+
:param earlystopping [EarlyStopping]: EarlyStopping object. Defaults to None.
172+
173+
:returns [TrainingSummary]: TrainingSummary object.
133174
"""
134175
logging.basicConfig(encoding="utf-8", level=logging.INFO)
135176

@@ -154,53 +195,6 @@ def train(
154195
else:
155196
best_path = None
156197

157-
def train_iteration(
158-
x: torch.Tensor,
159-
y: torch.Tensor,
160-
steps: int,
161-
optimizer: torch.optim.Optimizer,
162-
scheduler: torch.optim.lr_scheduler.LRScheduler,
163-
total_batch_iterations: int,
164-
) -> torch.Tensor:
165-
"""
166-
Run a single training iteration.
167-
168-
Args:
169-
x (torch.Tensor): Input training images.
170-
y (torch.Tensor): Input training labels.
171-
steps (int): Number of NCA inference time steps.
172-
optimizer (torch.optim.Optimizer): Optimizer.
173-
scheduler (torch.optim.lr_scheduler.LRScheduler): Scheduler.
174-
total_batch_iterations (int): Total training batch iterations
175-
176-
Returns:
177-
torch.Tensor: Predicted image.
178-
"""
179-
device = self.nca.device
180-
self.nca.train()
181-
optimizer.zero_grad()
182-
x_pred = x.clone().to(self.nca.device)
183-
if self.truncate_backprop:
184-
for step in range(steps):
185-
x_pred = self.nca(x_pred, steps=1)
186-
if step < steps - 10:
187-
x_pred.detach()
188-
else:
189-
x_pred = self.nca(x_pred, steps=steps)
190-
losses = self.nca.loss(x_pred, y.to(device))
191-
losses["total"].backward()
192-
193-
if self.gradient_clipping:
194-
torch.nn.utils.clip_grad_norm_(self.nca.parameters(), 1.0)
195-
optimizer.step()
196-
scheduler.step()
197-
if summary_writer:
198-
for key in losses:
199-
summary_writer.add_scalar(
200-
f"Loss/train_{key}", losses[key], total_batch_iterations
201-
)
202-
return x_pred
203-
204198
# MAIN LOOP
205199
total_batch_iterations = 0
206200
for iteration in tqdm(range(self.max_epochs), desc="Epochs"):
@@ -251,16 +245,26 @@ def train_iteration(
251245
y = torch.cat(self.batch_repeat * [y])
252246

253247
steps = np.random.randint(*self.steps_range)
254-
x_pred = train_iteration(
255-
x, y, steps, optimizer, scheduler, total_batch_iterations
248+
x_pred = self.train_iteration(
249+
x,
250+
y,
251+
steps,
252+
optimizer,
253+
scheduler,
254+
total_batch_iterations,
255+
summary_writer,
256256
)
257257
if self.p_retain_pool > 0.0:
258258
x_previous = x_pred
259259
total_batch_iterations += 1
260260

261261
with torch.no_grad():
262262
# VISUALIZATION
263-
if plot_function and summary_writer and (iteration + 1) % save_every == 0:
263+
if (
264+
plot_function
265+
and summary_writer
266+
and (iteration + 1) % save_every == 0
267+
):
264268
figure = plot_function(
265269
x.detach().cpu().numpy(),
266270
x_pred.detach().cpu().numpy(),

0 commit comments

Comments
 (0)