Skip to content

Commit 4c1dcf9

Browse files
committed
Wipe cache if wrong len
1 parent 4839584 commit 4c1dcf9

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

fms_fsdp/utils/dataset_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -677,10 +677,17 @@ def __iter__(self):
677677
dataset = iter(self.dataset)
678678
# Pad out buffer if needed
679679
self._pad_buffer()
680+
first_draw = next(dataset)
680681
while True:
682+
# If buffer entries have wrong length, reset buffer
683+
if len(first_draw) != len(self.buffer[0]):
684+
self.buffer = []
685+
self.buffer_size = 0
686+
self._pad_buffer()
687+
681688
# If buffer is undersized, add a datapoint
682689
if self.buffer_size < self.window_size:
683-
self.buffer[self.buffer_size] = next(dataset)
690+
self.buffer[self.buffer_size] = next(dataset) if self.buffer_size > 0 else first_draw
684691
self.buffer_size += 1
685692

686693
# Swap out randomly sampled value from buffer.
@@ -696,10 +703,10 @@ def __iter__(self):
696703
yield out
697704

698705
def _pad_buffer(self):
699-
if self.buffer_size < self.window_size:
706+
if len(self.buffer) < self.window_size:
700707
self.buffer += [
701708
[],
702-
] * (self.window_size - self.buffer_size)
709+
] * (len(self.buffer) - self.buffer_size)
703710

704711
def state_dict(self):
705712
# Write generator state manually

fms_fsdp/utils/train_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def train(
8484
for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
8585
if batch_idx > cfg.num_steps:
8686
break
87-
if rank == 0:
88-
print(input.shape)
8987
input = input.to(local_rank)
9088
label = label.to(local_rank)
9189

0 commit comments

Comments
 (0)