Merged
Changes from 5 commits
26 changes: 13 additions & 13 deletions apps/grpo/main.py
@@ -278,16 +278,16 @@ async def main(cfg: DictConfig):
# ---- Core RL loops ---- #
async def continuous_rollouts():
rollout_count = 0
- pad_id = await dataloader.pad_token.choose()
+ pad_id = await dataloader.pad_token.route()
while True:
- sample = await dataloader.sample.choose()
+ sample = await dataloader.sample.route()
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
- responses = await policy.generate.choose(prompt)
+ responses = await policy.generate.route(prompt)
# TODO: this shall be part of the responses metadata instead of a separate call
- version = await policy.get_version.choose()
+ version = await policy.get_version.route()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
@@ -311,29 +311,29 @@ async def continuous_rollouts():
episode.response = response.text
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
- episode.reward = await reward_actor.evaluate_response.choose(
+ episode.reward = await reward_actor.evaluate_response.route(
prompt=prompt, response=response.text, target=target
)

# Calculate reference logprobs
- ref_logits = await ref_model.forward.choose(input_ids)
+ ref_logits = await ref_model.forward.route(input_ids)
ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
for i, episode in enumerate(group.episodes):
episode.ref_logprobs = ref_logprobs[i]
del ref_logits, ref_logprobs, input_ids

# Calculate advantages and add to replay buffer
- advantages = await compute_advantages.compute.choose(group)
+ advantages = await compute_advantages.compute.route(group)
for episode, advantage in zip(group.episodes, advantages):
episode.advantage = advantage
- await replay_buffer.add.choose(episode)
+ await replay_buffer.add.route(episode)

# Log metrics
avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
- buffer_size = await replay_buffer._numel.choose()
+ buffer_size = await replay_buffer._numel.route()
mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
avg_reward = sum(e.reward for e in group.episodes) / group_size
mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
@@ -343,16 +343,16 @@ async def continuous_rollouts():
async def continuous_training():
training_step = 0
while True:
- batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
+ batch = await replay_buffer.sample.route(curr_policy_version=training_step)
if batch is None:
await asyncio.sleep(0.1)
else:
inputs, targets = batch
- loss = await trainer.train_step.choose(inputs, targets)
+ loss = await trainer.train_step.route(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
- await trainer.push_weights.call(training_step)
- await policy.update_weights.call(training_step)
+ await trainer.push_weights.fanout(training_step)
+ await policy.update_weights.fanout(training_step)

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all serivces support it
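
Taken together, the changes above settle into a simple call-site convention: per-sample work in the hot loop uses route(), while fleet-wide operations such as weight sync use fanout(). The sketch below restates that convention; `trainer`, `policy`, and `replay_buffer` are assumed Forge service handles and the calls are taken from the diff, so treat it as illustrative rather than part of this PR.

```python
# Hedged summary of the route()/fanout() split used in apps/grpo/main.py above.
# Assumes `trainer`, `policy`, and `replay_buffer` are Forge service handles.

async def one_training_step(trainer, policy, replay_buffer, step: int):
    # route(): load-balanced call answered by a single replica.
    batch = await replay_buffer.sample.route(curr_policy_version=step)
    if batch is None:
        return None
    inputs, targets = batch
    loss = await trainer.train_step.route(inputs, targets)

    # fanout(): broadcast to every healthy replica, e.g. to keep weights in sync.
    await trainer.push_weights.fanout(step)
    await policy.update_weights.fanout(step)
    return loss
```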
6 changes: 3 additions & 3 deletions apps/rl/main.py
@@ -161,11 +161,11 @@ async def run(cfg: DictConfig):
ref_logprobs=torch.randn((256,), generator=g),
advantage=torch.randn((1,), generator=g),
)
- await replay_buffer.add.choose(e)
+ await replay_buffer.add.route(e)

print("Train step...")
- inputs, targets = await replay_buffer.sample.choose(curr_policy_version=0)
- outputs = await trainer.train_step.choose(inputs, targets)
+ inputs, targets = await replay_buffer.sample.route(curr_policy_version=0)
+ outputs = await trainer.train_step.route(inputs, targets)
print("Loss: ", outputs["loss"])

print("Shutting down...")
6 changes: 3 additions & 3 deletions apps/sft_v2/main.py
@@ -289,13 +289,13 @@ async def run(cfg: DictConfig) -> None:
)

logging.info("Created recipe, running setup.")
- await recipe.setup.call()
+ await recipe.setup.fanout()

logging.info("Recipe has been setup. Training now.")
- await recipe.train.call()
+ await recipe.train.fanout()

logging.info("Done training. Clean up")
- await recipe.cleanup.call()
+ await recipe.cleanup.fanout()
await recipe.mesh.stop()
logging.info("All done!")

24 changes: 12 additions & 12 deletions apps/toy_rl/sumdigits.py
@@ -464,16 +464,16 @@ async def main(cfg: DictConfig):
# ---- Core RL loops ---- #
async def continuous_rollouts():
rollout_count = 0
- pad_id = await dataloader.pad_token.choose()
+ pad_id = await dataloader.pad_token.route()
while True:
# Pass rollout_count for curriculum learning
- sample = await dataloader.sample.choose(rollout_count)
+ sample = await dataloader.sample.route(rollout_count)
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
- responses = await policy.generate.choose(prompt)
- version = await policy.get_version.choose()
+ responses = await policy.generate.route(prompt)
+ version = await policy.get_version.route()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
@@ -491,13 +491,13 @@ async def continuous_rollouts():
episode.response_tokens = response.token_ids
episode.response = response.text
episode.response_logprobs = response.logprobs
- episode.ref_logprobs = await ref_model.forward.choose(episode)
- episode.reward = await reward_actor.evaluate_response.choose(
+ episode.ref_logprobs = await ref_model.forward.route(episode)
+ episode.reward = await reward_actor.evaluate_response.route(
prompt=prompt, response=response.text, target=target
)
episode.advantage = episode.reward # simple case for now
for episode in group.episodes:
- await replay_buffer.add.choose(episode)
+ await replay_buffer.add.route(episode)
avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
@@ -510,18 +510,18 @@ async def continuous_rollouts():
async def continuous_training():
training_step = 0
while True:
- batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
+ batch = await replay_buffer.sample.route(curr_policy_version=training_step)
if batch is None:
await asyncio.sleep(0.1)
else:
- loss = await trainer.train_step.choose(batch[0])
+ loss = await trainer.train_step.route(batch[0])
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
print(f"loss/training_step: {loss} at training step {training_step}")
- await trainer.push_weights.call(training_step)
- await policy.update_weights.call(training_step)
+ await trainer.push_weights.fanout(training_step)
+ await policy.update_weights.fanout(training_step)
# NOTE: hard-coded to be on-policy for faster convergence
- await replay_buffer.clear.call()
+ await replay_buffer.clear.fanout()

print("Starting training loop.")
# TODO: Start multiple rollouts once all serivces support it
2 changes: 1 addition & 1 deletion apps/vllm/main.py
@@ -39,7 +39,7 @@ async def run(cfg: DictConfig):
n = 100
start = time.time()
response_outputs: list[Completion] = await asyncio.gather(
- *[policy.generate.choose(prompt=prompt) for _ in range(n)]
+ *[policy.generate.route(prompt=prompt) for _ in range(n)]
)
end = time.time()

8 changes: 4 additions & 4 deletions src/forge/controller/service/interface.py
@@ -82,8 +82,8 @@ class ServiceEndpoint(Generic[P, R]):

This loosely mimics the Endpoint APIs exposed in Monarch, with
Review comment (Contributor):
Docstring would be like,

This extends Monarch's actor APIs:

  • route(): Routes the request to a replica
  • fanout(): Runs the request on all replicas

Note that Monarch's native actor APIs will not apply for services.

a few key differences:
- - Only choose and call are retained (dropping stream and call_one)
- - Call returns a list directly rather than a ValueMesh.
+ - Only choose(route) and call(fanout) are retained (dropping stream and call_one)
+ - Fanout returns a list directly rather than a ValueMesh.

These changes are made with Forge use cases in mind, but can
certainly be expanded/adapted in the future.
@@ -94,13 +94,13 @@ def __init__(self, service, endpoint_name: str):
self.service = service
self.endpoint_name = endpoint_name

- async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R:
+ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Chooses a replica to call based on context and load balancing strategy."""
# Extract sess_id from kwargs if present
sess_id = kwargs.pop("sess_id", None)
return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs)

- async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+ async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
"""Broadcasts a request to all healthy replicas and returns the results as a list."""
result = await self.service.call_all(self.endpoint_name, *args, **kwargs)
return result
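
For reference, a minimal call-site sketch of the renamed endpoints. The `policy` handle and its arguments are illustrative (borrowed from the apps above, not defined in this file): route() load-balances one request onto a single healthy replica, while fanout() broadcasts to every healthy replica and returns a plain list of results.

```python
# Illustrative sketch of the renamed ServiceEndpoint API; `policy` is assumed
# to be a Forge service handle exposing `generate` and `update_weights`.

async def demo(policy, version: int) -> None:
    # route(): the load balancer picks one replica and returns its result.
    completion = await policy.generate.route(prompt="What is 2 + 2?")

    # fanout(): broadcast to all healthy replicas; unlike Monarch's call(),
    # the results come back as a plain list (one entry per replica),
    # not a ValueMesh.
    results = await policy.update_weights.fanout(version)
    print(completion, len(results))
```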
12 changes: 6 additions & 6 deletions tests/integration_tests/test_policy_update.py
@@ -257,17 +257,17 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
procs=worker_size, with_gpus=True, num_replicas=1
).as_service(**trainer_cfg)

- await rl_trainer.push_weights.choose(policy_version=0)
+ await rl_trainer.push_weights.route(policy_version=0)
# 3. Policy pull weights
policy_config, service_config = get_configs(
worker_size=worker_size, tp_size=worker_size, model_name=self.model
)
policy = await Policy.options(**asdict(service_config)).as_service(
**policy_config
)
- await policy.update_weights.call()
+ await policy.update_weights.fanout()
# 4. Validate weights
- loaded_state_dict = await policy._get_model_params.choose()
+ loaded_state_dict = await policy._get_model_params.route()
validate_loaded_tensors_equals_original(
loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
)
@@ -297,21 +297,21 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
procs=trainer_worker_size, with_gpus=True, num_replicas=1
).as_service(**trainer_cfg_tp)

- await rl_trainer.push_weights.call(policy_version=0)
+ await rl_trainer.push_weights.fanout(policy_version=0)
# 3. Policy pull weights
policy_config, service_config = get_configs(
worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
)
policy = await Policy.options(**asdict(service_config)).as_service(
**policy_config
)
- await policy.update_weights.call()
+ await policy.update_weights.fanout()

# validate loaded shard of each worker againt manually calculated shard (expected shard).

# 4. Validate weight shards. We compare vLLM loades shard content with
# Directly loaded HF shard content.
- sharded_state_dicts = await policy._get_model_params.call()
+ sharded_state_dicts = await policy._get_model_params.fanout()
validate_loaded_tensors_equals_original(
sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0
)
12 changes: 4 additions & 8 deletions tests/unit_tests/test_replay_buffer.py
@@ -11,15 +11,12 @@
from forge.actors.replay_buffer import ReplayBuffer
from forge.types import Trajectory

- from monarch.actor import proc_mesh
-

class TestReplayBuffer:
@pytest_asyncio.fixture
async def replay_buffer(self) -> ReplayBuffer:
- mesh = await proc_mesh(gpus=1)
- replay_buffer = await mesh.spawn(
- "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1
+ replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+ batch_size=2, max_policy_age=1
)
await replay_buffer.setup.call()
return replay_buffer
@@ -112,10 +109,9 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
@pytest.mark.asyncio
async def test_sample_dp_size(self) -> None:
"""Test that len(samples) == dp_size when sampling."""
- mesh = await proc_mesh(gpus=1)
# Create replay buffer with dp_size=3
- replay_buffer = await mesh.spawn(
- "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1, dp_size=3
+ replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+ batch_size=2, max_policy_age=1, dp_size=3
)
await replay_buffer.setup.call()

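
The unit-test changes above also replace the raw proc_mesh/spawn calls with Forge's `.options(...).as_actor(...)` helper. A minimal sketch of that spawning pattern, assuming the ReplayBuffer actor API exactly as shown in the diff:

```python
# Sketch of the actor-spawn pattern the tests move to (replacing proc_mesh/spawn).
from forge.actors.replay_buffer import ReplayBuffer


async def make_buffer() -> ReplayBuffer:
    # .options() sizes the underlying process mesh; .as_actor() spawns a single
    # actor on it and returns the handle directly.
    buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
        batch_size=2, max_policy_age=1
    )
    # Plain actors (unlike services) keep Monarch's native endpoint API,
    # so setup is still invoked with .call().
    await buffer.setup.call()
    return buffer
```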