apps/grpo/main.py (42 changes: 19 additions & 23 deletions)
@@ -19,6 +19,8 @@
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from torchstore import MultiProcessStore
+from torchstore._state_dict_utils import DELIM, push_state_dict
 from transformers import AutoModelForCausalLM, AutoTokenizer

 logger = logging.getLogger(__name__)
@@ -87,11 +89,15 @@ def __init__(
         beta: float = 0.1,
         model_name: str = "",
         device: torch.device | None = None,
+        store: MultiProcessStore | None = None,
+        state_dict_key: str = "model_state_dict",
     ):
         super().__init__()
         self.learning_rate = learning_rate
         self.beta = beta  # KL penalty coefficient
         self.model_name = model_name
+        self.store = store
+        self.state_dict_key = state_dict_key

         # Set device
         if device is None:
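
For orientation, here is a minimal sketch of supplying the two new constructor arguments. It assumes MultiProcessStore exposes an async create_store() factory; that factory, and how the trainer actor itself is spawned, are assumptions for illustration, not shown in this diff.

# Sketch only: creating the store that backs weight pushing and preparing
# the two new constructor kwargs. create_store() is an assumed factory.
import asyncio

from torchstore import MultiProcessStore

async def main() -> None:
    store = await MultiProcessStore.create_store()  # assumed factory
    trainer_kwargs = {
        "store": store,                        # new in this PR
        "state_dict_key": "model_state_dict",  # new in this PR (the default)
    }
    print(trainer_kwargs)

asyncio.run(main())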
@@ -189,29 +195,15 @@ async def train_step(self, batch: list[Episode]):
         return {"loss": avg_loss, "groups_processed": num_groups_processed}

     @endpoint
-    async def update_weights(self, policy_actor):
+    async def push_weights(self, version: int):
         """Update policy model weights with trainer's current weights."""
-        # Time how long it takes to update weights
         start_time = time.time()
-
-        # Set model to eval mode for weight extraction
-        self.model.eval()
-
-        # Extract current model state dict
-        model_state_dict = self.model.state_dict()
-
-        # Convert tensors to CPU for transfer (if they're on GPU)
-        cpu_state_dict = {}
-        for key, tensor in model_state_dict.items():
-            cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
-
-        # Update the policy actor's model weights
-        await policy_actor.update_model_weights.choose(cpu_state_dict)
-
-        # Set model back to training mode
-        self.model.train()
-
-        # Log the time taken
+        assert self.store is not None, "Store must be provided to save weights"
+        await push_state_dict(
+            self.store,
+            self.model.state_dict(),
+            f"{self.state_dict_key}{DELIM}{version}",  # Use version as key
+        )
         end_time = time.time()
         self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")

Contributor, commenting on the push_state_dict call:
"help me understand how the RLTrainer + torchstore integration I'm doing overlaps with this code. @pbontrager"
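
For context, the reader side of this versioned-key scheme might look like the sketch below. It assumes torchstore._state_dict_utils provides a get_state_dict(store, key) counterpart to push_state_dict; that helper's name and signature are assumptions, not part of this diff.

# Hypothetical consumer of the weights pushed above (not part of this PR).
# get_state_dict is an assumed counterpart to push_state_dict.
from torchstore._state_dict_utils import DELIM, get_state_dict

async def pull_weights(store, model, state_dict_key: str, version: int) -> None:
    # push_weights wrote under f"{state_dict_key}{DELIM}{version}", so the
    # reader rebuilds the same versioned key to fetch a specific snapshot.
    key = f"{state_dict_key}{DELIM}{version}"
    state_dict = await get_state_dict(store, key)  # assumed helper
    model.load_state_dict(state_dict)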
Expand Down Expand Up @@ -463,8 +455,11 @@ async def continuous_rollouts():

async def continuous_training():
training_step = 0
policy_version = 0
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=0)
batch = await replay_buffer.sample.choose(
curr_policy_version=policy_version
)
if batch is None:
await asyncio.sleep(0.1)
else:
@@ -476,7 +471,8 @@ async def continuous_training():
                 loss_value = training_result.get("loss", 0.0)
                 print(f"Latest loss: {loss_value}")
                 logger.log("loss/training_step", loss_value, training_step)
-                # await trainer.update_weights(policy)
+                # await trainer.push_weights.choose(policy_version)
+                # policy_version += 1

     print("Starting GRPO training loops...")
     rollout_task = asyncio.create_task(continuous_rollouts())

Contributor, commenting on the commented-out push_weights line:
"nit: remove before merge"
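
Putting the commented-out lines together with the rest of this hunk, the intended loop once weight pushing is enabled would look roughly like the sketch below; replay_buffer and trainer are the actor handles defined elsewhere in apps/grpo/main.py, and nothing here is new API.

# Sketch: continuous_training with the commented-out lines enabled.
import asyncio

async def continuous_training(replay_buffer, trainer) -> None:
    training_step = 0
    policy_version = 0
    while True:
        batch = await replay_buffer.sample.choose(
            curr_policy_version=policy_version
        )
        if batch is None:
            await asyncio.sleep(0.1)  # nothing to train on yet
        else:
            await trainer.train_step.choose(batch)
            training_step += 1
            # Push the current weights under policy_version, then bump the
            # version so subsequent sampling targets the new policy.
            await trainer.push_weights.choose(policy_version)
            policy_version += 1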