
Commit 606e0fa

Merge pull request #125 from mr-yamraj/harsha/reorg

fixed total iterations for rolling training

2 parents: 5d0af0f + 514d3c1

File tree: 1 file changed (+29, -9)

examples/pytorch/FastCells/train_classifier.py (29 additions & 9 deletions)
@@ -166,24 +166,44 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
         trim_level = options.trim_level

         ticks = training_data.num_rows / batch_size  # iterations per epoch
-        total_iterations = ticks * num_epochs
-        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
-
-        # optimizer = optim.Adam(model.parameters(), lr=0.0001)
-        log = []
+
+        # Calculation of total iterations in non-rolling vs rolling training
+        # ticks = num_rows/batch_size (total number of iterations per epoch)
+        # Non-Rolling Training:
+        #     Total Iterations = num_epochs * ticks
+        # Rolling Training:
+        #     irl = Initial_rolling_length (we are using 2)
+        #     If num_epochs <= max_rolling_length:
+        #         Total Iterations = sum(range(irl, irl + num_epochs))
+        #     If num_epochs > max_rolling_length:
+        #         Total Iterations = sum(range(irl, irl + max_rolling_length)) + (num_epochs - max_rolling_length)*ticks
         if options.rolling:
             rolling_length = 2
             max_rolling_length = int(ticks)
-            if max_rolling_length > options.max_rolling_length:
-                max_rolling_length = options.max_rolling_length
+            if max_rolling_length > options.max_rolling_length + rolling_length:
+                max_rolling_length = options.max_rolling_length + rolling_length
             bag_count = 100
             hidden_bag_size = batch_size * bag_count
+            if num_epochs + rolling_length < max_rolling_length:
+                max_rolling_length = num_epochs + rolling_length
+            total_iterations = sum(range(rolling_length, max_rolling_length))
+            if num_epochs + rolling_length > max_rolling_length:
+                epochs_remaining = num_epochs + rolling_length - max_rolling_length
+                total_iterations += epochs_remaining * training_data.num_rows / batch_size
+            ticks = total_iterations / num_epochs
+        else:
+            total_iterations = ticks * num_epochs
+
+        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
+
+        # optimizer = optim.Adam(model.parameters(), lr=0.0001)
+        log = []

         for epoch in range(num_epochs):
             self.train()
             if options.rolling:
                 rolling_length += 1
-                if rolling_length < max_rolling_length:
+                if rolling_length <= max_rolling_length:
                     self.init_hidden_bag(hidden_bag_size, device)
             for i_batch, (audio, labels) in enumerate(training_data.get_data_loader(batch_size)):
                 if not self.batch_first:
@@ -197,7 +217,7 @@ def fit(self, training_data, validation_data, options, sparsify=False, device=No
                 # Also, we need to clear out the hidden state,
                 # detaching it from its history on the last instance.
                 if options.rolling:
-                    if rolling_length < max_rolling_length:
+                    if rolling_length <= max_rolling_length:
                         if (i_batch + 1) % rolling_length == 0:
                             self.init_hidden()
                             break
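
The new comment block doubles as a spec for the scheduler's iteration budget. A minimal standalone sketch of that arithmetic, assuming the initial rolling length of 2 used by the commit; the function name and the example values (10 epochs, 50 batches per epoch, rolling capped at 4 epochs) are illustrative, not from the repo:

# A standalone sketch (not from the commit) of the total-iterations
# formula spelled out in the new comment block; the function name and
# example values below are illustrative.

def total_iterations(num_epochs, ticks, rolling, irl=2, max_rolling_length=None):
    if not rolling:
        # Non-rolling: every epoch runs a full pass of `ticks` batches.
        return num_epochs * ticks
    if num_epochs <= max_rolling_length:
        # Training ends while still in the rolling phase:
        # epoch k runs (irl + k) truncated iterations.
        return sum(range(irl, irl + num_epochs))
    # Rolling phase first, then full-length epochs for the remainder.
    rolling_part = sum(range(irl, irl + max_rolling_length))
    full_part = (num_epochs - max_rolling_length) * ticks
    return rolling_part + full_part

print(total_iterations(10, 50, rolling=False))                       # 500
print(total_iterations(10, 50, rolling=True, max_rolling_length=4))  # 2+3+4+5 + 6*50 = 314

Note that the rolling branch also recomputes ticks = total_iterations / num_epochs, so configure_lr now sees the average per-epoch iteration count rather than the full num_rows / batch_size.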
