
Commit 0d145d2

fix type error in trainer
1 parent 773611b

1 file changed (+3 −12)

ncalab/training/trainer.py

Lines changed: 3 additions & 12 deletions
@@ -38,7 +38,6 @@ def __init__(
         lr_gamma: float = 0.9999,
         adam_betas=(0.9, 0.99),
         batch_repeat: int = 2,
-        truncate_backprop: bool = False,
         max_epochs: int = 200,
         p_retain_pool: float = 0.0,
         optimizer_method: str = "adamw",
@@ -55,7 +54,6 @@ def __init__(
         :param lr_gamma (float, optional): Exponential learning rate decay. Defaults to 0.9999.
         :param adam_betas (tuple, optional): Beta values for Adam optimizer. Defaults to (0.9, 0.99).
         :param batch_repeat (int, optional): How often each batch will be duplicated. Defaults to 2.
-        :param truncate_backprop (bool, optional): Whether to truncate backpropagation. Defaults to False.
         :param max_epochs (int, optional): Maximum number of epochs in training. Defaults to 200.
         :param p_retain_pool (float, optional): Probability at which a sample will be retained. Defaults to 0.0.
         :param optimizer_method: Optimization method. Defaults to 'adamw'.
@@ -77,11 +75,9 @@ def __init__(
         self.gradient_clipping = gradient_clipping
         self.steps_range = steps_range
         self.steps_validation = steps_validation
-        self.lr = lr
         self.lr_gamma = lr_gamma
         self.adam_betas = adam_betas
         self.batch_repeat = batch_repeat
-        self.truncate_backprop = truncate_backprop
         self.max_epochs = max_epochs
         self.p_retain_pool = p_retain_pool
         self.optimizer_method = optimizer_method
@@ -96,6 +92,8 @@ def __init__(
                 self.lr = 1e-2
             else:
                 self.lr = 1e-2
+        else:
+            self.lr = lr
 
     def info(self) -> str:
         """
@@ -113,7 +111,6 @@ def info(self) -> str:
             "gradient_clipping",
             "adam_betas",
             "batch_repeat",
-            "truncate_backprop",
             "max_epochs",
             "p_retain_pool",
             "optimizer_method",
@@ -148,13 +145,7 @@ def train_iteration(
         self.nca.train()
         optimizer.zero_grad()
         x_pred = x.clone().to(self.nca.device)
-        if self.truncate_backprop:
-            for step in range(steps):
-                x_pred = self.nca(x_pred, steps=1)
-                if step < steps - 10:
-                    x_pred.detach()
-        else:
-            x_pred = self.nca(x_pred, steps=steps)
+        x_pred = self.nca(x_pred, steps=steps)
         losses = self.nca.loss(x_pred, y.to(device))
         losses["total"].backward()
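
The deleted lines were the trainer's truncated-backpropagation path, which the earlier hunks also remove from the constructor and from `info()`. Worth noting: the removed `x_pred.detach()` call discarded its return value, and `Tensor.detach()` returns a new tensor rather than cutting the existing one out of the autograd graph, so that branch never actually truncated gradients. If truncation were wanted, the detached tensor would have to be assigned back, roughly as in the sketch below (the function name and the `keep_last` parameter are illustrative, not ncalab API; only the `nca(x, steps=...)` call shape is taken from the diff):

```python
import torch


def rollout_truncated(nca, x: torch.Tensor, steps: int, keep_last: int = 10) -> torch.Tensor:
    """Roll out an NCA for `steps` updates, backpropagating only through the last `keep_last`."""
    x_pred = x
    for step in range(steps):
        x_pred = nca(x_pred, steps=1)
        if step < steps - keep_last:
            # Reassign the detached tensor; calling .detach() without using the
            # result (as the removed code did) leaves the graph untouched.
            x_pred = x_pred.detach()
    return x_pred
```

Because the detach was a no-op, the truncated branch built the same graph as the plain `self.nca(x_pred, steps=steps)` call (assuming `steps=n` simply repeats the single-step update), so collapsing the two branches into one call simplifies the code without changing training behaviour.
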