diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 89c16b165..8337c6071 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -184,6 +184,7 @@ trainer:
   critic_mini_batch_size: 256
   micro_train_batch_size_per_gpu: 1
   micro_forward_batch_size_per_gpu: 1
+  max_tokens_per_microbatch: -1 # TODO: Maybe split this between forward and train; -1 means no token-based chunking
   update_ref_every_epoch: false
   use_sample_packing: true
   eval_batch_size: 1024
diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
index 080b9bb42..c2e07d7a1 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -30,7 +30,7 @@
 from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype
 from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
 from skyrl_train.training_batch import TrainingOutputBatch
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, reduce_metrics
 from skyrl_train.workers.worker import (
     PolicyWorkerBase,
     RefWorkerBase,
@@ -512,7 +512,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
         Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches
         into the worker MegatronModelWrapper.forward_backward_mini_batch method.
         """
-        dataloader = BatchIterator(
+        dataloader = SampleBasedBatchIterator(
             train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
         )
 
@@ -538,7 +538,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
         # TODO: Convert this into 2 loops for minibatches and microbatches.
         micro_buffer = []
         for local_step, microbatch in enumerate(pbar):
-            experience = BatchIterator.batch_to_experience(microbatch)
+            experience = BaseBatchIterator.batch_to_experience(microbatch)
             experience.to_device(torch.cuda.current_device())
             sequences = experience.sequences
             attention_mask = experience.attention_mask
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 361a0777e..791e0df65 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,12 @@
 from loguru import logger
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import (
+    BaseBatchIterator,
+    SampleBasedBatchIterator,
+    TokenBasedBatchIterator,
+    reduce_metrics,
+)
 from skyrl_train.dataset.replay_buffer import Experience
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
 from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -43,6 +48,15 @@
 _SET_AFFINITY = False
 
 
+def _get_microbatch_iterator(
+    data: TrainingInputBatch, micro_batch_size: int, max_tokens_per_microbatch: int
+) -> BaseBatchIterator:
+    if max_tokens_per_microbatch > 0:
+        return TokenBasedBatchIterator(data, max_tokens_per_microbatch=max_tokens_per_microbatch)
+    else:
+        return SampleBasedBatchIterator(data, sample_batch_size=micro_batch_size, drop_last=False)
+
+
 # Adapted from OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ray/launcher.py#L17
 class DistributedTorchRayActor:
     def __init__(
@@ -305,16 +319,22 @@ def forward(
     ) -> TrainingOutputBatch:
         """Run forward pass on the input batch in inference mode.
 
-        This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.trainer.micro_forward_batch_size_per_gpu`.
+        This is a wrapper around `_forward_micro_batch` that runs in micro batches.
+        Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise
+        falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`.
         """
-        # run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
         # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
-        micro_batches = data.chunk(self.cfg.trainer.micro_forward_batch_size_per_gpu)
+        microbatch_iterator = _get_microbatch_iterator(
+            data,
+            micro_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu,
+            max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
+        )
         outputs = []
-        for micro_batch in micro_batches:
-            outputs.append(self._forward_micro_batch(micro_batch))
-        output = TrainingOutputBatch.cat(outputs)
+        for microbatch in microbatch_iterator:
+            outputs.append(self._forward_micro_batch(microbatch))
+        output = microbatch_iterator.reorder_and_combine_batches(outputs)
+
         if output.device is not None and output.device != torch.device("cpu"):
             output = output.to("cpu")
         return output
@@ -724,7 +744,7 @@ def optim_step(self) -> float:
 
     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
-        minibatch_iterator = BatchIterator(
+        minibatch_iterator = SampleBasedBatchIterator(
             train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False
         )
 
@@ -784,16 +804,23 @@ def record_status(status: Dict[str, float]):
                 disable=not self.strategy.is_rank_0(),
             )
             for minibatch in minibatch_pbar:
-                microbatch_iterator = BatchIterator(
-                    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+                microbatch_iterator = _get_microbatch_iterator(
+                    minibatch,
+                    micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
+                    max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
                 )
                 num_microbatches = len(microbatch_iterator)
-                microbatch_weight = 1.0 / num_microbatches
 
                 for microbatch_idx, microbatch in enumerate(microbatch_iterator):
-                    microbatch_experience = BatchIterator.batch_to_experience(microbatch)
+                    microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
+                    microbatch_weight = len(microbatch) / len(minibatch)
                     status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                     # Record status for all but the last microbatch in the minibatch.
                     # The last microbatch should be recorded after the optimizer step.
                     if microbatch_idx < num_microbatches - 1:
@@ -991,7 +1020,7 @@ def save_hf_model(self, export_dir: str, tokenizer):
 
     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
-        minibatch_iterator = BatchIterator(
+        minibatch_iterator = SampleBasedBatchIterator(
            train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False
         )
 
@@ -1019,14 +1048,16 @@ def record_status(status: Dict[str, float]):
                 disable=not self.strategy.is_rank_0(),
             )
             for minibatch in minibatch_pbar:
-                microbatch_iterator = BatchIterator(
-                    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+                microbatch_iterator = _get_microbatch_iterator(
+                    minibatch,
+                    micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
+                    max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
                 )
                 num_microbatches = len(microbatch_iterator)
-                microbatch_weight = 1.0 / num_microbatches
 
                 for microbatch_idx, microbatch in enumerate(microbatch_iterator):
-                    microbatch_experience = BatchIterator.batch_to_experience(microbatch)
+                    microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
+                    microbatch_weight = len(microbatch) / len(minibatch)
                     status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                     if microbatch_idx < num_microbatches - 1:
diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index 897d032ea..2913635a5 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -1,7 +1,12 @@
+import heapq
 import math
+from typing import List, Dict, Iterator
+
+import torch
+import torch.distributed as dist
+
 from skyrl_train.dataset.replay_buffer import Experience
-from typing import List, Dict
-from skyrl_train.training_batch import TrainingInputBatch
+from skyrl_train.training_batch import TrainingInputBatch, TensorBatch
 
 
 def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
@@ -16,33 +21,19 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
     return reduced_metrics
 
 
-class BatchIterator:
-    """A simple iterator to yield micro batches of data from the training batch."""
-
-    def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False):
+class BaseBatchIterator:
+    def __init__(self, data: TrainingInputBatch):
         self.data = data
-        self.sample_batch_size = sample_batch_size
-        self.total_batch_size = data.batch_size
-        self.drop_last = drop_last
-        assert not drop_last, "drop_last is not supported yet"
-        num_micro_batches = self.total_batch_size / self.sample_batch_size
-        self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches)
-        # TODO: switch to tensordict.map_iter if possible
-        self._chunks = self.data.chunk(self.sample_batch_size)
-        self._iter = iter(self._chunks)
 
     def __len__(self):
-        return self.num_micro_batches
+        raise NotImplementedError
 
-    def __iter__(self):
-        return self
+    def __iter__(self) -> Iterator[TrainingInputBatch]:
+        raise NotImplementedError
 
-    def __next__(self) -> TrainingInputBatch:
-        try:
-            return next(self._iter)
-        except StopIteration:
-            self._iter = iter(self._chunks)
-            raise StopIteration
+    def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
+        """Reorder and combine output batches to form a single output."""
+        raise NotImplementedError
 
     @staticmethod
     def batch_to_experience(batch: TrainingInputBatch):
@@ -67,3 +58,246 @@ def batch_to_experience(batch: TrainingInputBatch):
             metadata=batch.metadata,
         )
         return exp
+
+
+class SampleBasedBatchIterator(BaseBatchIterator):
+    """A simple iterator that yields equally sized sample batches from the training input."""
+
+    def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False):
+        self.data = data
+        self.sample_batch_size = sample_batch_size
+        self.total_batch_size = data.batch_size
+        self.drop_last = drop_last
+        assert not drop_last, "drop_last is not supported yet"
+        num_micro_batches = self.total_batch_size / self.sample_batch_size
+        self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches)
+        # TODO: switch to tensordict.map_iter if possible
+        self._chunks = self.data.chunk(self.sample_batch_size)
+        self._iter = iter(self._chunks)
+
+    def __len__(self):
+        return self.num_micro_batches
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> TrainingInputBatch:
+        try:
+            return next(self._iter)
+        except StopIteration:
+            self._iter = iter(self._chunks)
+            raise StopIteration
+
+    def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
+        """Concatenate output batches to form a single output.
+
+        This iterator splits the input into contiguous batches in their original order,
+        so there is no need to reorder."""
+        return TensorBatch.cat(batches)
+
+
+def balanced_binpacking(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
+    """Chunk a list of token counts into microbatches so that each
+    microbatch's total token count does not exceed `max_tokens_per_microbatch`,
+    and the microbatches are roughly balanced.
+
+    Microbatches are kept roughly balanced by assigning each sequence to the
+    microbatch with the fewest tokens so far.
+
+    Args:
+        token_counts: List of token counts for each sample.
+        max_tokens_per_microbatch: Maximum total tokens allowed per microbatch.
+
+    Returns:
+        A list of microbatches, where each microbatch is a list of indices (ints)
+        referring to entries in `token_counts`.
+
+    >>> balanced_binpacking([10, 10, 5, 5], 15)
+    [[0, 2], [1, 3]]
+    >>> balanced_binpacking([10, 1, 1, 1, 1, 1], 10)
+    [[0], [1, 2, 3, 4, 5]]
+    >>> balanced_binpacking([8, 3, 5, 6, 2, 7], 11)
+    [[0, 4], [5, 1], [3, 2]]
+    """
+    # TODO: Handle max(token_counts) > max_tokens_per_microbatch
+
+    # Create a list of (index, token_count) pairs and sort by token count descending
+    seq_lens = [(i, seq_len) for i, seq_len in enumerate(token_counts)]
+    seq_lens.sort(key=lambda x: x[1], reverse=True)
+
+    # Sample indices assigned to each microbatch
+    microbatch_indices: List[List[int]] = []
+
+    # Heap to track the total number of tokens in each microbatch
+    microbatch_tokens_heap = []  # (current_total, bin_idx)
+
+    for idx, seq_len in seq_lens:
+        placed = False
+
+        # Look at the existing microbatch with the fewest tokens and check whether
+        # it can fit the sequence without exceeding the token limit.
+        if microbatch_tokens_heap:
+            microbatch_len, i = microbatch_tokens_heap[0]
+            new_microbatch_len = microbatch_len + seq_len
+            if new_microbatch_len <= max_tokens_per_microbatch:
+                microbatch_indices[i].append(idx)
+                heapq.heapreplace(microbatch_tokens_heap, (new_microbatch_len, i))
+                placed = True
+
+        # If no microbatch can fit the sequence, create a new microbatch.
+        if not placed:
+            microbatch_indices.append([idx])
+            heapq.heappush(microbatch_tokens_heap, (seq_len, len(microbatch_indices) - 1))
+
+    return microbatch_indices
+
+
+class TokenBasedBatchIterator(BaseBatchIterator):
+    """An iterator that chunks a batch into microbatches based on actual (non-padding) token count.
+
+    Packs samples into microbatches, ensuring each microbatch doesn't exceed
+    max_tokens_per_microbatch. All data parallel workers will have the same number of
+    microbatches (padding microbatches are added if needed).
+    """
+
+    def __init__(
+        self,
+        data: TrainingInputBatch,
+        max_tokens_per_microbatch: int,
+    ):
+        """
+        Args:
+            data: The training input batch to chunk
+            max_tokens_per_microbatch: Maximum number of tokens per microbatch
+        """
+        self._data = data
+        self._max_tokens_per_microbatch = max_tokens_per_microbatch
+
+        # Compute per-sample token counts (non-padding tokens) from the attention mask
+        attention_mask = data["attention_mask"]
+        self._token_counts = attention_mask.sum(dim=1).cpu().tolist()  # [batch_size]
+
+        # Create microbatches based on token count
+        # TODO: Allow for different chunking strategies.
+        self._microbatches = balanced_binpacking(self._token_counts, self._max_tokens_per_microbatch)
+
+        # Synchronize the number of microbatches across all DP workers
+        max_num_microbatches = self._sync_num_microbatches()
+
+        self._num_padding_microbatches = max_num_microbatches - len(self._microbatches)
+
+    @property
+    def num_microbatches(self) -> int:
+        return len(self._microbatches) + self._num_padding_microbatches
+
+    def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
+        """Create a TrainingInputBatch from a list of sample indices."""
+        indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
+        selected_data = {}
+        for key, value in self._data.items():
+            selected_data[key] = value[indices_tensor]
+        microbatch = TrainingInputBatch(selected_data)
+        microbatch.metadata = self._data.metadata
+        return microbatch
+
+    def _create_padding_microbatch(self) -> TrainingInputBatch:
+        """Create a dummy padding microbatch so that all DP ranks run the same number of steps."""
+        seq_len = 2
+        num_actions = self._data.metadata["response_length"]
+        batch_size = 1
+
+        data = TrainingInputBatch(
+            {
+                "sequences": torch.randint(0, 100, (batch_size, seq_len), device="cpu"),
+                "attention_mask": torch.ones((batch_size, seq_len), dtype=int, device="cpu"),
+                "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"),
+                "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"),
+                "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
+                "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
+                "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
+                # Loss mask is all zeros because we don't want padding samples to contribute to the loss.
+ "loss_mask": torch.zeros((batch_size, num_actions), dtype=int, device="cpu"), + "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + } + ) + data.metadata = self._data.metadata + return data + + def _sync_num_microbatches(self) -> int: + """Ensure all DP workers have the same number of micro batches.""" + local_num_microbatches = len(self._microbatches) + + if not dist.is_initialized(): + return local_num_microbatches + + # Get the maximum number of batches across all DP workers + if torch.cuda.is_available(): + device = torch.cuda.current_device() + else: + device = torch.device("cpu") + num_microbatches_tensor = torch.tensor(local_num_microbatches, dtype=torch.long, device=device) + dist.all_reduce(num_microbatches_tensor, op=dist.ReduceOp.MAX) + return num_microbatches_tensor.item() + + def __len__(self): + return len(self._microbatches) + self._num_padding_microbatches + + def __iter__(self): + for microbatch in self._microbatches: + microbatch_data = self._create_microbatch_from_indices(microbatch) + yield microbatch_data + + for _ in range(self._num_padding_microbatches): + padding_microbatch = self._create_padding_microbatch() + yield padding_microbatch + + def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch: + """Reorder and combine output batches into a single batch with + the same order as the original input data. + + Example: [[0, 2], [1, 3]] -> [0, 1, 2, 3] + + Args: + batches: List of microbatches to reorder. + Returns: + A single reordered batch. + """ + non_padding_microbatches = batches[: len(batches) - self._num_padding_microbatches] + + if not non_padding_microbatches: + # TODO: Can this happen? + raise ValueError("Cannot reorder an empty list of microbatches.") + + # Create a reverse mapping of original idx -> (microbatch idx, sample idx) + original_idx_to_microbatch_idx = {} + + for microbatch_idx, original_indices in enumerate(self._microbatches): + for sample_idx, original_idx in enumerate(original_indices): + original_idx_to_microbatch_idx[original_idx] = (microbatch_idx, sample_idx) + + # Get reference microbatch to know keys and tensor shapes + ref_microbatch = non_padding_microbatches[0] + reordered_data = {} + + for key, ref_value in ref_microbatch.items(): + # Get shape of a single sample (remove batch dimension) + sample_shape = ref_value.shape[1:] + device = ref_value.device + dtype = ref_value.dtype + + # Pre-allocate output tensor: [batch_size, *sample_shape] + batch_size = len(self._token_counts) + output_tensor = torch.zeros((batch_size, *sample_shape), dtype=dtype, device=device) + + # Copy each sample directly into the correct position + for original_idx in range(batch_size): + microbatch_idx, sample_idx = original_idx_to_microbatch_idx[original_idx] + source_tensor = non_padding_microbatches[microbatch_idx][key] + output_tensor[original_idx] = source_tensor[sample_idx] + + reordered_data[key] = output_tensor + + # Create single TensorBatch with reordered data + reordered_batch = type(ref_microbatch)(reordered_data) + reordered_batch.metadata = ref_microbatch.metadata + return reordered_batch diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index 599f8a3f4..140da8df9 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -14,7 +14,7 @@ from skyrl_train.training_batch import TrainingInputBatch import numpy as np from skyrl_train.workers.worker import PolicyWorkerBase, CriticWorkerBase -from 
skyrl_train.workers.worker_utils import BatchIterator +from skyrl_train.workers.worker_utils import SampleBasedBatchIterator from skyrl_train.utils.utils import validate_batch_sizes from skyrl_train.config.utils import get_default_config from tests.cpu.util import example_dummy_config @@ -559,7 +559,7 @@ def mock_policy_forward_backward(experience, microbatch_weight): policy_worker.record_memory = False # Calculate expected values based on new accumulation logic - dataloader = BatchIterator( + dataloader = SampleBasedBatchIterator( dummy_databatch, sample_batch_size=cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False ) total_micro_batches = len(dataloader) # Should be 6 diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py index c9880a017..9d1782fac 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py @@ -218,3 +218,75 @@ def test_gradient_accumulation_scenarios( print(f" - Actual optimizer steps: {actual_optimizer_steps}") finally: ray.shutdown() + + +import torch +from skyrl_train.training_batch import TrainingInputBatch + +def make_dummy_batch(seq_lens, num_actions=4) -> TrainingInputBatch: + """Create a dummy TrainingInputBatch""" + + batch_size = len(seq_lens) + max_seq_len = max(seq_lens) + + sequences = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + attention_mask = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + for i, seq_len in enumerate(seq_lens): + sequences[i, :seq_len] = torch.randint(0, 100, (seq_len,), dtype=int, device="cpu") + attention_mask[i, :seq_len] = 1 + + print(sequences) + print(attention_mask) + + # Add all the required fields for training + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"), + "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"), + "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), + "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + } + ) + data.metadata = {"response_length": num_actions} + return data + + + +@pytest.mark.parametrize("worker_type", ["policy", "critic"]) +def test_max_tokens_per_microbatch(ray_init_fixture, cfg, worker_type): + try: + cfg.trainer.strategy = "fsdp2" # Strategy logic is not tested here. 
+        cfg.trainer.max_tokens_per_microbatch = 15
+
+        # Hard-code two GPUs on a single node, i.e. two data parallel workers.
+        cfg.trainer.placement.policy_num_gpus_per_node = 2
+        cfg.trainer.policy_mini_batch_size = 8
+
+        actor_group = init_worker_with_type(
+            worker_type,
+            shared_pg=None,
+            colocate_all=False,
+            num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node,
+            cfg=cfg,
+        )
+
+        train_data = make_dummy_batch([10, 10, 6, 5, 10, 10, 5, 5], num_actions=4)
+        # Expected microbatch token totals per data parallel rank:
+        # - dp=0: 3 microbatches with [10], [10], [11]
+        # - dp=1: 3 microbatches with [10, 5], [10, 5], [] (the last one is a padding microbatch)
+        train_data.metadata["global_step"] = 0
+
+        # Exercise both the forward and the ppo_train paths.
+        results = ray.get(actor_group.async_run_ray_method("mesh", "forward", train_data))
+        results = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", train_data))
+        assert len(results) == cfg.trainer.placement.policy_num_gpus_per_node, "Should get result from each GPU"
+
+        # TODO: Add assertions for the results.
+
+    finally:
+        ray.shutdown()
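
Note (not part of the patch): a minimal sketch of the packing behaviour the new `balanced_binpacking` helper and `TokenBasedBatchIterator` are expected to show, assuming the patch above is applied and `skyrl_train` is importable. The per-rank shard contents mirror the comments in `test_max_tokens_per_microbatch`.

    # Illustration only; assumes the patched skyrl_train package is importable.
    from skyrl_train.workers.worker_utils import balanced_binpacking

    # dp=0 shard of the test batch: token counts [10, 10, 6, 5] with a 15-token budget.
    # Each 10-token sample stays alone (10 + 10 > 15); the 6- and 5-token samples pack together (11 <= 15).
    assert balanced_binpacking([10, 10, 6, 5], 15) == [[0], [1], [2, 3]]

    # dp=1 shard: [10, 10, 5, 5] packs into two full 15-token microbatches, so TokenBasedBatchIterator
    # pads this rank with one dummy microbatch to match the three microbatch steps on dp=0.
    assert balanced_binpacking([10, 10, 5, 5], 15) == [[0, 2], [1, 3]]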