@@ -166,24 +166,44 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
166
166
trim_level = options .trim_level
167
167
168
168
ticks = training_data .num_rows / batch_size # iterations per epoch
169
- total_iterations = ticks * num_epochs
170
- scheduler = self .configure_lr (options , optimizer , ticks , total_iterations )
171
-
172
- # optimizer = optim.Adam(model.parameters(), lr=0.0001)
173
- log = []
169
+
170
+ # Calculation of total iterations in non-rolling vs rolling training
171
+ # ticks = num_rows/batch_size (total number of iterations per epoch)
172
+ # Non-Rolling Training:
173
+ # Total Iteration = num_epochs * ticks
174
+ # Rolling Training:
175
+ # irl = Initial_rolling_length (We are using 2)
176
+ # If num_epochs <= max_rolling_length:
177
+ # Total Iterations = sum(range(irl, irl + num_epochs))
178
+ # If num_epochs > max_rolling_length:
179
+ # Total Iterations = sum(range(irl, irl + max_rolling_length)) + (num_epochs - max_rolling_length)*ticks
174
180
if options .rolling :
175
181
rolling_length = 2
176
182
max_rolling_length = int (ticks )
177
- if max_rolling_length > options .max_rolling_length :
178
- max_rolling_length = options .max_rolling_length
183
+ if max_rolling_length > options .max_rolling_length + rolling_length :
184
+ max_rolling_length = options .max_rolling_length + rolling_length
179
185
bag_count = 100
180
186
hidden_bag_size = batch_size * bag_count
187
+ if num_epochs + rolling_length < max_rolling_length :
188
+ max_rolling_length = num_epochs + rolling_length
189
+ total_iterations = sum (range (rolling_length , max_rolling_length ))
190
+ if num_epochs + rolling_length > max_rolling_length :
191
+ epochs_remaining = num_epochs + rolling_length - max_rolling_length
192
+ total_iterations += epochs_remaining * training_data .num_rows / batch_size
193
+ ticks = total_iterations / num_epochs
194
+ else :
195
+ total_iterations = ticks * num_epochs
196
+
197
+ scheduler = self .configure_lr (options , optimizer , ticks , total_iterations )
198
+
199
+ # optimizer = optim.Adam(model.parameters(), lr=0.0001)
200
+ log = []
181
201
182
202
for epoch in range (num_epochs ):
183
203
self .train ()
184
204
if options .rolling :
185
205
rolling_length += 1
186
- if rolling_length < max_rolling_length :
206
+ if rolling_length <= max_rolling_length :
187
207
self .init_hidden_bag (hidden_bag_size , device )
188
208
for i_batch , (audio , labels ) in enumerate (training_data .get_data_loader (batch_size )):
189
209
if not self .batch_first :
@@ -197,7 +217,7 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
197
217
# Also, we need to clear out the hidden state,
198
218
# detaching it from its history on the last instance.
199
219
if options .rolling :
200
- if rolling_length < max_rolling_length :
220
+ if rolling_length <= max_rolling_length :
201
221
if (i_batch + 1 ) % rolling_length == 0 :
202
222
self .init_hidden ()
203
223
break
0 commit comments