
Commit 14d072e

[Feat,BugFix] avoid shuffling train data by default; do not evaluate rollout on very last epoch

1 parent 29be07f

File tree

1 file changed: +9 -5 lines changed


rl4co/models/rl/common/base.py

Lines changed: 9 additions & 5 deletions
@@ -36,7 +36,7 @@ class RL4COLitModule(LightningModule):
         lr_scheduler_interval: learning rate scheduler interval
         lr_scheduler_monitor: learning rate scheduler monitor
         generate_default_data: whether to generate default datasets, filling up the data directory
-        shuffle_train_dataloader: whether to shuffle training dataloader
+        shuffle_train_dataloader: whether to shuffle the training dataloader. Default is False since we recreate the dataset every epoch
         dataloader_num_workers: number of workers for dataloader
         data_dir: data directory
         metrics: metrics
@@ -50,7 +50,7 @@ def __init__(
         batch_size: int = 512,
         val_batch_size: int = None,
         test_batch_size: int = None,
-        train_data_size: int = 1_280_000,
+        train_data_size: int = 100_000,
         val_data_size: int = 10_000,
         test_data_size: int = 10_000,
         optimizer: Union[str, torch.optim.Optimizer, partial] = "Adam",
@@ -63,7 +63,7 @@ def __init__(
         lr_scheduler_interval: str = "epoch",
         lr_scheduler_monitor: str = "val/reward",
         generate_default_data: bool = False,
-        shuffle_train_dataloader: bool = True,
+        shuffle_train_dataloader: bool = False,
         dataloader_num_workers: int = 0,
         data_dir: str = "data/",
         log_on_step: bool = True,
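
The two default changes above work together: because the training set is re-sampled from scratch at the end of every epoch (see the on_train_epoch_end hunk below), instances already arrive in fresh random order, so shuffling the dataloader is redundant work. The following is a minimal sketch of how such a flag would typically reach the DataLoader; make_train_dataloader and the toy dataset are illustrative assumptions, not rl4co's actual dataloader code:

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_train_dataloader(dataset, batch_size: int = 512, shuffle: bool = False, num_workers: int = 0):
    # Hypothetical helper mirroring the constructor defaults above:
    # shuffle defaults to False because a freshly sampled dataset
    # carries no meaningful order that shuffling would need to destroy.
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

# Stand-in for env.dataset(train_data_size, "train"): 100_000 random instances.
dataset = TensorDataset(torch.randn(100_000, 2))
loader = make_train_dataloader(dataset)  # shuffle=False by default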
@@ -278,8 +278,12 @@ def on_train_epoch_end(self):
         """Called at the end of the training epoch. This can be used for instance to update the train dataset
         with new data (which is the case in RL).
         """
-        train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
-        self.train_dataset = self.wrap_dataset(train_dataset)
+        # Only regenerate if this is not the last epoch
+        # If last epoch, we don't need to update since we will not use the dataset anymore
+        if self.current_epoch < self.trainer.max_epochs - 1:
+            log.info("Generating training dataset for next epoch...")
+            train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
+            self.train_dataset = self.wrap_dataset(train_dataset)

     def wrap_dataset(self, dataset):
         """Wrap dataset with policy-specific wrapper. This is useful i.e. in REINFORCE where we need to
