@@ -36,7 +36,7 @@ class RL4COLitModule(LightningModule):
         lr_scheduler_interval: learning rate scheduler interval
         lr_scheduler_monitor: learning rate scheduler monitor
         generate_default_data: whether to generate default datasets, filling up the data directory
-        shuffle_train_dataloader: whether to shuffle training dataloader
+        shuffle_train_dataloader: whether to shuffle the training dataloader. Default is False since we recreate the dataset every epoch
         dataloader_num_workers: number of workers for dataloader
         data_dir: data directory
         metrics: metrics
@@ -50,7 +50,7 @@ def __init__(
         batch_size: int = 512,
         val_batch_size: int = None,
         test_batch_size: int = None,
-        train_data_size: int = 1_280_000,
+        train_data_size: int = 100_000,
         val_data_size: int = 10_000,
         test_data_size: int = 10_000,
         optimizer: Union[str, torch.optim.Optimizer, partial] = "Adam",
@@ -63,7 +63,7 @@ def __init__(
         lr_scheduler_interval: str = "epoch",
         lr_scheduler_monitor: str = "val/reward",
         generate_default_data: bool = False,
-        shuffle_train_dataloader: bool = True,
+        shuffle_train_dataloader: bool = False,
         dataloader_num_workers: int = 0,
         data_dir: str = "data/",
         log_on_step: bool = True,
@@ -278,8 +278,12 @@ def on_train_epoch_end(self):
         """Called at the end of the training epoch. This can be used for instance to update the train dataset
         with new data (which is the case in RL).
         """
-        train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
-        self.train_dataset = self.wrap_dataset(train_dataset)
+        # Only update if not in the last epoch:
+        # the regenerated dataset would never be used after the final epoch
+        if self.current_epoch < self.trainer.max_epochs - 1:
+            log.info("Generating training dataset for next epoch...")
+            train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
+            self.train_dataset = self.wrap_dataset(train_dataset)
 
     def wrap_dataset(self, dataset):
         """Wrap dataset with policy-specific wrapper. This is useful i.e. in REINFORCE where we need to
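Taken together, these changes regenerate the training data at every epoch boundary, which is why the dataloader no longer needs `shuffle=True`. Below is a minimal, self-contained sketch of the same pattern; it is not RL4CO's actual code, and the dataset construction, module, and Trainer flags are illustrative assumptions. In particular, it assumes the Trainer reloads dataloaders each epoch via `reload_dataloaders_every_n_epochs=1`, so that `train_dataloader()` is called again with the fresh dataset.

```python
# Sketch (not RL4CO's actual code) of regenerating training data each epoch
# and skipping regeneration after the final epoch, as in the diff above.
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningModule, Trainer


class EpochRegeneratingModule(LightningModule):
    def __init__(self, train_data_size: int = 1_000, batch_size: int = 64):
        super().__init__()
        self.train_data_size = train_data_size
        self.batch_size = batch_size
        self.net = torch.nn.Linear(8, 1)
        self.train_dataset = self._make_dataset()  # initial dataset

    def _make_dataset(self):
        # Stand-in for env.dataset(train_data_size, "train"): fresh random data
        return TensorDataset(torch.randn(self.train_data_size, 8))

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.net(x).pow(2).mean()  # dummy loss for the sketch

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def on_train_epoch_end(self):
        # Skip regeneration on the last epoch: the data would never be used
        if self.current_epoch < self.trainer.max_epochs - 1:
            self.train_dataset = self._make_dataset()

    def train_dataloader(self):
        # Fresh data every epoch, so shuffling adds nothing but overhead
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False)


if __name__ == "__main__":
    # reload_dataloaders_every_n_epochs=1 makes Lightning rebuild the
    # dataloader at each epoch start, picking up the regenerated dataset
    trainer = Trainer(max_epochs=3, reload_dataloaders_every_n_epochs=1)
    trainer.fit(EpochRegeneratingModule())
```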