@@ -38,7 +38,6 @@ def __init__(
3838 lr_gamma : float = 0.9999 ,
3939 adam_betas = (0.9 , 0.99 ),
4040 batch_repeat : int = 2 ,
41- truncate_backprop : bool = False ,
4241 max_epochs : int = 200 ,
4342 p_retain_pool : float = 0.0 ,
4443 optimizer_method : str = "adamw" ,
@@ -55,7 +54,6 @@ def __init__(
5554 :param lr_gamma (float, optional): Exponential learning rate decay. Defaults to 0.9999.
5655 :param adam_betas (tuple, optional): Beta values for Adam optimizer. Defaults to (0.9, 0.99).
5756 :param batch_repeat (int, optional): How often each batch will be duplicated. Defaults to 2.
58- :param truncate_backprop (bool, optional): Whether to truncate backpropagation. Defaults to False.
5957 :param max_epochs (int, optional): Maximum number of epochs in training. Defaults to 200.
6058 :param p_retain_pool (float, optional): Probability at which a sample will be retained. Defaults to 0.0.
6159 :param optimizer_method: Optimization method. Defaults to 'adamw'.
@@ -77,11 +75,9 @@ def __init__(
7775 self .gradient_clipping = gradient_clipping
7876 self .steps_range = steps_range
7977 self .steps_validation = steps_validation
80- self .lr = lr
8178 self .lr_gamma = lr_gamma
8279 self .adam_betas = adam_betas
8380 self .batch_repeat = batch_repeat
84- self .truncate_backprop = truncate_backprop
8581 self .max_epochs = max_epochs
8682 self .p_retain_pool = p_retain_pool
8783 self .optimizer_method = optimizer_method
@@ -96,6 +92,8 @@ def __init__(
9692 self .lr = 1e-2
9793 else :
9894 self .lr = 1e-2
95+ else :
96+ self .lr = lr
9997
10098 def info (self ) -> str :
10199 """
@@ -113,7 +111,6 @@ def info(self) -> str:
113111 "gradient_clipping" ,
114112 "adam_betas" ,
115113 "batch_repeat" ,
116- "truncate_backprop" ,
117114 "max_epochs" ,
118115 "p_retain_pool" ,
119116 "optimizer_method" ,
@@ -148,13 +145,7 @@ def train_iteration(
148145 self .nca .train ()
149146 optimizer .zero_grad ()
150147 x_pred = x .clone ().to (self .nca .device )
151- if self .truncate_backprop :
152- for step in range (steps ):
153- x_pred = self .nca (x_pred , steps = 1 )
154- if step < steps - 10 :
155- x_pred .detach ()
156- else :
157- x_pred = self .nca (x_pred , steps = steps )
148+ x_pred = self .nca (x_pred , steps = steps )
158149 losses = self .nca .loss (x_pred , y .to (device ))
159150 losses ["total" ].backward ()
160151
0 commit comments