-
Notifications
You must be signed in to change notification settings - Fork 53
Enable single sample processing #1380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
37f586f
d7af544
c7340ab
c756837
37b9c76
d0ec01d
fb1b995
df2b885
f168f05
a67b5c3
349b1ab
f0e09d7
7703a12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,6 +113,8 @@ def __init__( | |
| self.forecast_policy = cf.forecast_policy | ||
|
|
||
| self.len = 100000000 | ||
| self.samples_per_mini_epoch = samples_per_mini_epoch | ||
| self.repeat_data = cf.get("repeat_data", False) | ||
|
|
||
| self.streams_datasets: list[list[AnyDataReader]] = [] | ||
| for _, stream_info in enumerate(cf.streams): | ||
|
|
@@ -186,7 +188,12 @@ def __init__( | |
|
|
||
| index_range = self.time_window_handler.get_index_range() | ||
| self.len = int(index_range.end - index_range.start) | ||
| self.len = min(self.len, samples_per_mini_epoch if samples_per_mini_epoch else self.len) | ||
| if not self.repeat_data: | ||
| self.len = min(self.len, samples_per_mini_epoch if samples_per_mini_epoch else self.len) | ||
| else: | ||
| assert samples_per_mini_epoch, "must specify samples_per_mini_epoch if repeat_data" | ||
| self.len = samples_per_mini_epoch | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if I understand correctly, but do we not want to run an epoch with e.g.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is addressed in line 278, but may not cover some edge cases.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe the underlying question is whether we need to adjust |
||
| # adjust len to split loading across all workers and ensure it is multiple of batch_size | ||
| len_chunk = ((self.len // cf.world_size) // batch_size) * batch_size | ||
| self.len = min(self.len, len_chunk) | ||
|
|
@@ -269,10 +276,16 @@ def reset(self): | |
| idx_end = index_range.end | ||
| # native length of datasets, independent of mini_epoch length that has potentially been | ||
| # specified | ||
| forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs | ||
| forecast_len = (self.len_hrs * (fsm)) // self.step_hrs # NOTE: why was it fsm +1? | ||
| idx_end -= forecast_len + self.forecast_offset | ||
|
|
||
| assert idx_end > 0, "dataset size too small for forecast range" | ||
| self.perms = np.arange(index_range.start, idx_end) | ||
| if self.repeat_data: | ||
| assert self.samples_per_mini_epoch == self.len, "length of sampler was set different from samples_per_mini_epoch –- aborting to avoid unintended effects" | ||
| assert self.len % len(self.perms) == 0, "length of permutations is not a multiple of length of available data –- aborting to avoid unintended effects" | ||
| self.perms = np.tile(self.perms, self.len // len(self.perms)) #TODO: maybe use samples_per_mini_epoch? | ||
|
|
||
| if self.shuffle: | ||
| self.perms = self.rng.permutation(self.perms) | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove this comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's set the
end_dates for train and val to 201401011800, such that we can keep (fsm + 1) below.