diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 145b6cd48..d2b49a835 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -19,6 +19,8 @@
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from torchstore import MultiProcessStore
+from torchstore._state_dict_utils import DELIM, push_state_dict
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 logger = logging.getLogger(__name__)
@@ -87,11 +89,15 @@ def __init__(
         beta: float = 0.1,
         model_name: str = "",
         device: torch.device | None = None,
+        store: MultiProcessStore | None = None,
+        state_dict_key: str = "model_state_dict",
     ):
         super().__init__()
         self.learning_rate = learning_rate
         self.beta = beta  # KL penalty coefficient
         self.model_name = model_name
+        self.store = store
+        self.state_dict_key = state_dict_key
 
         # Set device
         if device is None:
@@ -189,29 +195,15 @@ async def train_step(self, batch: list[Episode]):
         return {"loss": avg_loss, "groups_processed": num_groups_processed}
 
     @endpoint
-    async def update_weights(self, policy_actor):
+    async def push_weights(self, version: int):
         """Update policy model weights with trainer's current weights."""
-        # Time how long it takes to update weights
         start_time = time.time()
-
-        # Set model to eval mode for weight extraction
-        self.model.eval()
-
-        # Extract current model state dict
-        model_state_dict = self.model.state_dict()
-
-        # Convert tensors to CPU for transfer (if they're on GPU)
-        cpu_state_dict = {}
-        for key, tensor in model_state_dict.items():
-            cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
-
-        # Update the policy actor's model weights
-        await policy_actor.update_model_weights.choose(cpu_state_dict)
-
-        # Set model back to training mode
-        self.model.train()
-
-        # Log the time taken
+        assert self.store is not None, "Store must be provided to save weights"
+        await push_state_dict(
+            self.store,
+            self.model.state_dict(),
+            f"{self.state_dict_key}{DELIM}{version}",  # Use version as key
+        )
         end_time = time.time()
         self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
 
@@ -463,8 +455,11 @@ async def continuous_rollouts():
 
     async def continuous_training():
         training_step = 0
+        policy_version = 0
         while True:
-            batch = await replay_buffer.sample.choose(curr_policy_version=0)
+            batch = await replay_buffer.sample.choose(
+                curr_policy_version=policy_version
+            )
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
@@ -476,7 +471,8 @@
                 loss_value = training_result.get("loss", 0.0)
                 print(f"Latest loss: {loss_value}")
                 logger.log("loss/training_step", loss_value, training_step)
-                # await trainer.update_weights(policy)
+                # await trainer.push_weights.choose(policy_version)
+                # policy_version += 1
 
     print("Starting GRPO training loops...")
     rollout_task = asyncio.create_task(continuous_rollouts())
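
For reference, below is a minimal sketch (not part of the patch) of the version-keyed round trip that `push_weights` sets up: each snapshot is stored under `f"{state_dict_key}{DELIM}{version}"`, so a consumer can fetch a specific policy version by rebuilding the same key. Only `push_state_dict` and `DELIM` appear in the patch itself; `MultiProcessStore.create_store()` and the `get_state_dict` reader helper are assumptions here.

```python
# Sketch only -- illustrates the version-keyed layout used by push_weights above.
# Assumed APIs (not shown in the patch): MultiProcessStore.create_store() and a
# get_state_dict() helper symmetric to push_state_dict().
import asyncio

import torch
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, get_state_dict, push_state_dict

STATE_DICT_KEY = "model_state_dict"  # must match the trainer's state_dict_key


async def main() -> None:
    store = await MultiProcessStore.create_store()  # assumed constructor

    # Producer side: roughly what push_weights(version=0) does with the real model.
    weights = {"layer.weight": torch.zeros(2, 2)}
    await push_state_dict(store, weights, f"{STATE_DICT_KEY}{DELIM}0")

    # Consumer side: rebuild the same version-suffixed key to fetch that snapshot.
    restored = await get_state_dict(store, f"{STATE_DICT_KEY}{DELIM}0")
    print(f"Restored tensors for version 0: {list(restored)}")


if __name__ == "__main__":
    asyncio.run(main())
```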