From 88358a98cfd6acebb6792169cc42f2acdc0807a6 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Wed, 20 Aug 2025 11:38:47 -0700 Subject: [PATCH 01/15] Pushing Policy Worker:rough --- src/forge/actors/policy.py | 96 +++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 6866a6949..c096a1657 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -14,6 +14,7 @@ import torch +from forge.interfaces import Policy as PolicyInterface from monarch.actor import Actor, current_rank, endpoint, proc_mesh from torchstore import MultiProcessStore @@ -45,18 +46,21 @@ @dataclass -class PolicyRouter(Actor): +class Policy(PolicyInterface): # TODO: Add dp support - policy: Actor + policy_worker: Actor = None sampling_params: SamplingParams = None lora_request: LoRARequest = None tokenization_kwargs: dict = None @endpoint - async def setup(self): + async def setup(self, config, guided_decoding=False, num_samples=1): + # Set up workers + await self.setupWorker(config, guided_decoding, num_samples) + self.request_id = 0 self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {} - self.vllm_args = await self.policy.get_vllm_args.choose() + self.vllm_args = await self.policy_worker.get_vllm_args.choose() # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` @@ -72,7 +76,7 @@ async def setup(self): # Setup schduuler # TODO: Add support for `log_stats` - kv_cache_configs = await self.policy.setup_kv_cache.call() + kv_cache_configs = await self.policy_worker.setup_kv_cache.call() kv_cache_config = kv_cache_configs._values[0] self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks self.vllm_args.cache_config.num_cpu_blocks = 0 @@ -86,6 +90,34 @@ async def setup(self): log_stats=None, ) + async def setupWorker(self, config, guided_decoding, num_samples): + self.worker_mesh = await proc_mesh( + gpus=config["resources"], + env={ + "MASTER_ADDR": str(get_loopback_ip()), + "MASTER_PORT": str(get_open_port()), + }, + ) + self.policy_worker = await self.worker_mesh.spawn( + "policy_worker", PolicyWorker, **config + ) + + # TODO: Make this customizable from the config + vllm_args = await self.policy_worker.get_vllm_args.choose() + overrides = { + "n": num_samples, + "guided_decoding": ( + GuidedDecodingParams(choice=["Positive", "Negative"]) + if guided_decoding + else None + ), + } + self.sampling_params = get_default_sampling_params( + vllm_args, overrides=overrides + ) + + await self.policy_worker.setup.call() + @endpoint async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]: self.request_id += 1 % sys.maxsize @@ -169,7 +201,9 @@ async def run(self): self.running = True while self.running: scheduler_output = self.scheduler.schedule() - worker_outputs = await self.policy.execute_model.call(scheduler_output) + worker_outputs = await self.policy_worker.execute_model.call( + scheduler_output + ) worker_output = worker_outputs._values[output_rank] outputs = self.scheduler.update_from_output(scheduler_output, worker_output) outputs = outputs.get(0) or EngineCoreOutputs() @@ -185,13 +219,18 @@ async def run(self): _, fut = self.requests.pop(request_output.request_id) fut.set_result(request_output.outputs) + @endpoint + async def update_weights(self): + """Update the policy weights.""" + pass + @endpoint async def shutdown(self): self.running = False @dataclass -class 
Policy(Actor): +class PolicyWorker(Actor): model: str tensor_parallel_size: int = 1 pipeline_parallel_size: int = 1 @@ -389,46 +428,17 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams: async def _test(config, guided_decoding=False, num_samples=1): # TODO: Create proper test - router_mesh = await proc_mesh(gpus=1) - policy_mesh = await proc_mesh( - gpus=config["resources"], - env={ - "MASTER_ADDR": str(get_loopback_ip()), - "MASTER_PORT": str(get_open_port()), - }, - ) - - policy_actor = await policy_mesh.spawn("policy", Policy, **config) - - # TODO: Make this customizable from the config - overrides = { - "n": num_samples, - "guided_decoding": ( - GuidedDecodingParams(choice=["Positive", "Negative"]) - if guided_decoding - else None - ), - } - - vllm_args = await policy_actor.get_vllm_args.choose() - sampling_params = get_default_sampling_params(vllm_args, overrides=overrides) - - router = await router_mesh.spawn( - "policy_router", - PolicyRouter, - policy=policy_actor, - sampling_params=sampling_params, - ) + policy_mesh = await proc_mesh(gpus=1) + policy = await policy_mesh.spawn("policy", Policy) - await policy_actor.setup.call() - await router.setup.call() print("Model setup") + await policy.setup.call(config, guided_decoding, num_samples) - router.run.call() print("Model running") + policy.run.call() prompt = "What is 3+5?" if guided_decoding else "Tell me a joke" - responses: List[CompletionOutput] = await router.generate.call_one(prompt) + responses: List[CompletionOutput] = await policy.generate.call_one(prompt) for batch, response in enumerate(responses): print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") @@ -436,7 +446,7 @@ async def _test(config, guided_decoding=False, num_samples=1): print(f"User: {prompt}\nAssistant: {response.text}") print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - await router.shutdown.call() + await policy.shutdown.call() if __name__ == "__main__": @@ -450,4 +460,4 @@ async def _test(config, guided_decoding=False, num_samples=1): # asyncio.run(_test(config)) # asyncio.run(_test(config, guided_decoding=True)) # asyncio.run(_test(config, num_samples=2)) - # asyncio.run(_test(config, guided_decoding=True, num_samples=3)) + asyncio.run(_test(config, guided_decoding=True, num_samples=3)) From d4c83d35ea21e7a0020aef4e40b4653bebe79999 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 21 Aug 2025 10:10:28 -0700 Subject: [PATCH 02/15] Partial paths for sapwn_actor and spawn service --- src/forge/actors/policy.py | 131 ++++++++++++++++++++++--------------- src/forge/interfaces.py | 18 ++++- 2 files changed, 96 insertions(+), 53 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index c096a1657..74bb05b72 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -13,9 +13,16 @@ from typing import Dict, List import torch +from forge.controller import spawn_actors +from forge.controller.service import ServiceConfig +from forge.controller.spawn import spawn_service +from forge.data.sharding import VLLMSharding from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig + from monarch.actor import Actor, current_rank, endpoint, proc_mesh +from omegaconf import DictConfig, OmegaConf from torchstore import MultiProcessStore from torchstore._state_dict_utils import DELIM @@ -40,27 +47,43 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from 
forge.data.sharding import VLLMSharding - logger = logging.getLogger(__name__) @dataclass class Policy(PolicyInterface): # TODO: Add dp support + config: DictConfig + # Gets set up by setup worker policy_worker: Actor = None + sampling_params: SamplingParams = None lora_request: LoRARequest = None tokenization_kwargs: dict = None @endpoint - async def setup(self, config, guided_decoding=False, num_samples=1): - # Set up workers - await self.setupWorker(config, guided_decoding, num_samples) + async def setup(self): + # Set up policy_worker + await self.spawn_workers() self.request_id = 0 self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {} self.vllm_args = await self.policy_worker.get_vllm_args.choose() + + # Setup sampling params + sampling_overrides = self.config.sampling_params + overrides = { + "n": sampling_overrides.num_samples, + "guided_decoding": ( + GuidedDecodingParams(choice=["Positive", "Negative"]) + if sampling_overrides.guided_decoding + else None + ), + } + self.sampling_params = get_default_sampling_params( + self.vllm_args, overrides=overrides + ) + # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` @@ -74,7 +97,7 @@ async def setup(self, config, guided_decoding=False, num_samples=1): ) self.output_processor = OutputProcessor(tokenizer, log_stats=None) - # Setup schduuler + # Setup scheduler # TODO: Add support for `log_stats` kv_cache_configs = await self.policy_worker.setup_kv_cache.call() kv_cache_config = kv_cache_configs._values[0] @@ -90,32 +113,20 @@ async def setup(self, config, guided_decoding=False, num_samples=1): log_stats=None, ) - async def setupWorker(self, config, guided_decoding, num_samples): + def should_spawn_workers(self) -> bool: + return True + + async def spawn_workers(self): self.worker_mesh = await proc_mesh( - gpus=config["resources"], + gpus=self.config.num_workers, env={ "MASTER_ADDR": str(get_loopback_ip()), "MASTER_PORT": str(get_open_port()), }, ) self.policy_worker = await self.worker_mesh.spawn( - "policy_worker", PolicyWorker, **config - ) - - # TODO: Make this customizable from the config - vllm_args = await self.policy_worker.get_vllm_args.choose() - overrides = { - "n": num_samples, - "guided_decoding": ( - GuidedDecodingParams(choice=["Positive", "Negative"]) - if guided_decoding - else None - ), - } - self.sampling_params = get_default_sampling_params( - vllm_args, overrides=overrides + "policy_worker", PolicyWorker, **self.config.worker_params ) - await self.policy_worker.setup.call() @endpoint @@ -125,8 +136,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu # Wraps prompt into a dict prompt: Dict[str, str] = convert_input(prompt) - if self.sampling_params is None: - self.sampling_params = get_default_sampling_params(self.vllm_args) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -236,7 +245,6 @@ class PolicyWorker(Actor): pipeline_parallel_size: int = 1 enforce_eager: bool = False vllm_args: EngineArgs = None - resources: int = 1 state_dict_key: str = "model_state_dict" def __post_init__(self): @@ -279,7 +287,6 @@ def __post_init__(self): setattr(self.vllm_args, key, value) # Build Config self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS) - assert self.vllm_args.parallel_config.world_size == self.resources @endpoint async def setup(self, store: MultiProcessStore = None): @@ -426,19 +433,32 @@ def get_default_sampling_params(vllm_config, overrides=None) -> 
SamplingParams: return params -async def _test(config, guided_decoding=False, num_samples=1): - # TODO: Create proper test - policy_mesh = await proc_mesh(gpus=1) - policy = await policy_mesh.spawn("policy", Policy) +# TODO: Create proper test +async def _test(config: DictConfig): + prompt = ( + "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke" + ) - print("Model setup") - await policy.setup.call(config, guided_decoding, num_samples) - - print("Model running") - policy.run.call() + with_service = False + if with_service: + service_config = ServiceConfig( + procs_per_replica=1, min_replicas=1, max_replicas=2, default_replicas=1 + ) + print("spawning service") + service = await spawn_service(service_config, Policy, config=config) + service.run() + responses: List[CompletionOutput] = await service.generate(prompt) - prompt = "What is 3+5?" if guided_decoding else "Tell me a joke" - responses: List[CompletionOutput] = await policy.generate.call_one(prompt) + else: + process_config = ProcessConfig() + policy = await spawn_actors( + name="policy", actor_cls=Policy, cfg=config, processes=process_config + ) + print("Model setup") + await policy.setup.call() + print("Model running") + policy.run.call() + responses: List[CompletionOutput] = await policy.generate.call_one(prompt) for batch, response in enumerate(responses): print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") @@ -446,18 +466,27 @@ async def _test(config, guided_decoding=False, num_samples=1): print(f"User: {prompt}\nAssistant: {response.text}") print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - await policy.shutdown.call() + if with_service: + await service.stop() + else: + await policy.shutdown.call() if __name__ == "__main__": - config = { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 1, - "enforce_eager": True, - "resources": 2, - } - # asyncio.run(_test(config)) - # asyncio.run(_test(config, guided_decoding=True)) - # asyncio.run(_test(config, num_samples=2)) - asyncio.run(_test(config, guided_decoding=True, num_samples=3)) + config = OmegaConf.create( + { + "num_workers": 2, + "worker_params": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 1, + "enforce_eager": True, + "vllm_args": None, + }, + "sampling_params": { + "guided_decoding": True, + "num_samples": 2, + }, + } + ) + asyncio.run(_test(config)) diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index f19f379cb..94eed7b80 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -7,10 +7,10 @@ from abc import ABC, abstractmethod from typing import Any -from monarch.actor import Actor, endpoint - from forge.types import Action, Message, Observation, State +from monarch.actor import Actor, endpoint + class Transform(ABC): """Abstract base class for observation transforms. @@ -87,6 +87,20 @@ async def update_weights(self): """Update the policy weights.""" pass + @abstractmethod + def should_spawn_workers(self) -> bool: + """Whether the policy needs to separately spawrn child workers.""" + pass + + @abstractmethod + def spawn_workers(self): + """ + Spawn child workers used by this actor + + No-op when should_spawn_workers() is False. 
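To make the two new interface hooks concrete, here is a minimal sketch (not part of this patch; the class name is illustrative and the interface's other abstract methods are omitted) of a policy that opts out of child workers:

from forge.interfaces import Policy as PolicyInterface


class InProcessPolicy(PolicyInterface):
    """Illustrative policy that runs fully in-process (no worker mesh)."""

    def should_spawn_workers(self) -> bool:
        # No separate PolicyWorker processes are needed.
        return False

    def spawn_workers(self):
        # No-op, as permitted when should_spawn_workers() returns False.
        pass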
+ """ + pass + class BaseTokenizer(ABC): """ From 8d67786342a9e60b5e392322d00a8a3ca1433df2 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 21 Aug 2025 13:59:02 -0700 Subject: [PATCH 03/15] Commenting sections to update post Service Refactor --- src/forge/actors/policy.py | 6 ++---- src/forge/interfaces.py | 24 +++++++++++++----------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 74bb05b72..af5705c16 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -24,7 +24,6 @@ from monarch.actor import Actor, current_rank, endpoint, proc_mesh from omegaconf import DictConfig, OmegaConf from torchstore import MultiProcessStore - from torchstore._state_dict_utils import DELIM from vllm.engine.arg_utils import EngineArgs @@ -441,9 +440,8 @@ async def _test(config: DictConfig): with_service = False if with_service: - service_config = ServiceConfig( - procs_per_replica=1, min_replicas=1, max_replicas=2, default_replicas=1 - ) + # Update this condition path once Service has been refactored + service_config = ServiceConfig(procs_per_replica=1, num_replicas=1) print("spawning service") service = await spawn_service(service_config, Policy, config=config) service.run() diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 94eed7b80..d780b829c 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -87,19 +87,21 @@ async def update_weights(self): """Update the policy weights.""" pass - @abstractmethod - def should_spawn_workers(self) -> bool: - """Whether the policy needs to separately spawrn child workers.""" - pass + # TODO: Update Based on Service Refactor - @abstractmethod - def spawn_workers(self): - """ - Spawn child workers used by this actor + # @abstractmethod + # def should_spawn_workers(self) -> bool: + # """Whether the policy needs to separately spawrn child workers.""" + # pass - No-op when should_spawn_workers() is False. - """ - pass + # @abstractmethod + # def spawn_workers(self): + # """ + # Spawn child workers used by this actor + + # No-op when should_spawn_workers() is False. + # """ + # pass class BaseTokenizer(ABC): From 5a3807ebbf208ac9ead94f472e2d8175886e2324 Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Thu, 21 Aug 2025 18:08:15 -0700 Subject: [PATCH 04/15] make dataset configurable and add validation loop (#54) * make dataset configurable * add validation loop * update config * fix infinite loop, current_num_tokens * add single pass to the parameters * add single pass to param * fix the hang issue * minor: update error message * move batch_to_device to utils and add support to blockmask * remove comment * clean * fix validation backward thing for pp * remove self.model * add max_steps for validation to avoid hang * remove infinite --- apps/sft/llama3_8b.yaml | 14 ++++++- apps/sft/main.py | 86 +++++++++++++++++++++++++++++++++++------ src/forge/data/utils.py | 31 +++++++++++++++ 3 files changed, 119 insertions(+), 12 deletions(-) diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 573b401bd..37c24b69d 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -31,7 +31,19 @@ training: max_norm: 1.0 steps: 1000 compile: false - dataset: "c4" + +validation: + local_batch_size: 1 + freq: -1 # Change to a positive number to enable validation + steps: 200 # Max steps to run validation. Validation disabled if negative. 
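The freq/steps fields above drive the gate added to apps/sft/main.py later in this same patch; the helper below is an illustrative, self-contained restatement of that check (the function name is not from the patch):

def should_validate(current_step: int, freq: int, steps: int) -> bool:
    # Validation runs every `freq` training steps, for at most `steps` batches;
    # a non-positive freq or steps disables it entirely.
    return freq > 0 and steps > 0 and current_step % freq == 0


assert should_validate(current_step=200, freq=100, steps=200)
assert not should_validate(current_step=200, freq=-1, steps=200)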
+ +dataset: + path: yahma/alpaca-cleaned + split: train[:95%] + +dataset_val: + path: yahma/alpaca-cleaned + split: train[95%:] parallelism: data_parallel_replicate_degree: 1 diff --git a/apps/sft/main.py b/apps/sft/main.py index 9781dad5c..b5ae6fc16 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -18,9 +18,11 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX from omegaconf import DictConfig, OmegaConf from torch import nn + from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.components.loss import LossFunction from torchtitan.components.lr_scheduler import LRSchedulersContainer @@ -30,6 +32,7 @@ from torchtitan.experiments.forge.job_config import ForgeJobConfig from tqdm import tqdm + # stubs for now Checkpointer = Any Dataloader = Any @@ -63,7 +66,16 @@ def __init__(self, job_config: ForgeJobConfig): self.metric_logger = None # TODO: fix this def setup(self): - self.train_dataloader = self.setup_data() + self.train_dataloader = self.setup_data( + self.job_config.dataset, + batch_size=self.job_config.training.local_batch_size, + ) + + self.val_dataloader = self.setup_data( + self.job_config.dataset_val, + batch_size=self.job_config.validation.local_batch_size, + ) + # self.train_dataloader = self.setup_data( # self.train_config.train_dataset_config, # self.train_config.train_dataloader_config, @@ -79,7 +91,7 @@ def setup(self): # self.profiler = self.setup_profiler(self.train_config.profiler_config) # self.logger = self.setup_logger(self.train_config.logger_config) - def setup_data(self): + def setup_data(self, dataset_config, batch_size): tokenizer = HuggingFaceModelTokenizer( tokenizer_json_path=os.path.join( self.job_config.model.hf_assets_path, "tokenizer.json" @@ -95,8 +107,8 @@ def setup_data(self): dataset = sft_iterable_dataset( model_transform=tokenizer, message_transform=AlpacaToMessages(), - path="yahma/alpaca-cleaned", - split="train", + path=dataset_config.path, + split=dataset_config.split, ) packer = TextPacker(padding_idx=0) dataset = PackedDataset( @@ -106,7 +118,7 @@ def setup_data(self): ) dataloader = StatefulDataLoader( dataset=dataset, - batch_size=self.job_config.training.local_batch_size, + batch_size=batch_size, collate_fn=partial( collate_packed, mask_fn=packer.create_block_mask, device=self.device ), @@ -119,7 +131,10 @@ def setup_data(self): return dataloader def forward_backward( - self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + do_backward: bool = True, ) -> torch.Tensor: model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -145,14 +160,16 @@ def forward_backward( targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) + if do_backward: + pp_schedule_fn = self.pp_schedule.step + else: + pp_schedule_fn = self.pp_schedule.eval if self.pp_has_first_stage: - self.pp_schedule.step( + pp_schedule_fn( inputs, target=targets, losses=losses, input_batch=inputs ) else: - self.pp_schedule.step( - target=targets, losses=losses, input_batch=inputs - ) + pp_schedule_fn(target=targets, losses=losses, input_batch=inputs) # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU @@ -170,7 +187,8 @@ def forward_backward( loss = 
self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred - loss.backward() + if do_backward: + loss.backward() return loss @@ -214,6 +232,52 @@ def train(self) -> None: last_step=self.current_step == self.num_training_steps, ) + if ( + self.job_config.validation.freq > 0 + and self.job_config.validation.steps > 0 + and self.current_step % self.job_config.validation.freq == 0 + ): + self.validate(self.job_config.validation.steps) + + def validate(self, max_steps: int) -> None: + for m in self.model_parts: + m.eval() + total_val_loss = torch.tensor(0.0, device=self.device) + total_val_tokens = torch.tensor(0.0, device=self.device) + with torch.no_grad(): + val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False) + for batch_idx, batch in enumerate(val_pbar): + if batch_idx >= max_steps: + break + batch_to_device(batch, self.device) + current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum() + # Compute loss + labels = batch.pop("labels") + loss = self.forward_backward(batch, labels, do_backward=False) + val_loss = loss * current_num_tokens + total_val_loss += val_loss + total_val_tokens += current_num_tokens + # Update progress bar description with current average loss + avg_loss_so_far = ( + (total_val_loss / total_val_tokens).item() + if total_val_tokens > 0 + else float("inf") + ) + val_pbar.set_description( + f"Running validation Loss: {avg_loss_so_far:.4f}" + ) + # Aggregate validation metrics across all ranks + torch.distributed.all_reduce(total_val_loss) + torch.distributed.all_reduce(total_val_tokens) + avg_val_loss = ( + (total_val_loss / total_val_tokens).item() + if total_val_tokens > 0 + else float("inf") + ) + for m in self.model_parts: + m.train() + print(f"\nValidation loss: {avg_val_loss}") + def cleanup(self) -> None: if self.checkpointer: self.checkpointer.close() diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index 4cc58f2d3..f6f4fa73c 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -7,6 +7,10 @@ from enum import Enum from typing import Any, Literal, Optional, Union +import torch + +from torch.nn.attention.flex_attention import BlockMask + CROSS_ENTROPY_IGNORE_IDX = -100 Role = Literal[ @@ -182,3 +186,30 @@ def mask_messages( message.masked = True elif masking_strategy == MaskingStrategy.TRAIN_ON_ASSISTANT: message.masked = message.role != "assistant" + + +def batch_to_device(batch: dict, device: torch.device) -> None: + """Function that takes a dictionary (or nested dictionary) of tensors and sets them + all to the same device. This utility is intended to be used for batches of data to be + moved to device, the update is inplace. + + Args: + batch (dict): dict of Tensors or more nested dicts of tensors. + device (torch.device): torch device to move the tensors to. + + Raises: + ValueError: if batch dict contains anything other than ``torch.Tensor``. + + """ + for k, v in batch.items(): + if isinstance(v, dict): + batch_to_device(v, device) + elif isinstance(v, torch.Tensor): + batch[k] = v.to(device) + elif isinstance(v, BlockMask): + batch[k] = v.to(device) + else: + raise ValueError( + f"""To use batch_to_device, all elements in the batch must be a dict, Tensor, or BlockMask with flexattention enabled. 
+Got key "{k}" with value of type {type(v)}""" + ) From db603299bb446a1a0cfc2029073ecd057cd03c9b Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:15:57 -0700 Subject: [PATCH 05/15] Add math/thinking reward (#64) * Add reward interface, math reward, unit tests * move test files to rl folder * add thinking reward --- src/forge/data/rewards/math.py | 56 ++++++ src/forge/data/rewards/thinking.py | 19 ++ src/forge/interfaces.py | 13 +- tests/unit_tests/rl/test_math_reward.py | 202 ++++++++++++++++++++ tests/unit_tests/rl/test_thinking_reward.py | 144 ++++++++++++++ 5 files changed, 432 insertions(+), 2 deletions(-) create mode 100644 src/forge/data/rewards/math.py create mode 100644 src/forge/data/rewards/thinking.py create mode 100644 tests/unit_tests/rl/test_math_reward.py create mode 100644 tests/unit_tests/rl/test_thinking_reward.py diff --git a/src/forge/data/rewards/math.py b/src/forge/data/rewards/math.py new file mode 100644 index 000000000..06dd6cbc6 --- /dev/null +++ b/src/forge/data/rewards/math.py @@ -0,0 +1,56 @@ +import re +from typing import Optional + +from forge.interfaces import Reward + + +class MathReward(Reward): + """Reward class for evaluating math correctness.""" + + def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1): + self.tolerance = tolerance + self.partial_credit = partial_credit + + def _to_float(self, text) -> Optional[float]: + """Safely parse a string into a float, or return None if invalid.""" + if text is None: + return None + try: + return float(str(text).strip()) + except (ValueError, TypeError): + return None + + def _extract_number(self, text: str) -> Optional[float]: + """Try to extract a numeric answer from text.""" + number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)" + patterns = [ + r"####\s*" + number_pattern, + r"(?:the\s+)?answer\s+is\s*" + number_pattern, + r"(?:answer:|result:)\s*" + number_pattern, + r"\$" + number_pattern, # currency + number_pattern, # fallback + r"=\s*" + number_pattern + r"\s*(?:\.|$)", + r"\b" + number_pattern + r"\s*(?:\.|$)", + ] + text = text.lower().strip() + for pattern in patterns: + matches = re.findall(pattern, text) + if matches: + return self._to_float(matches[-1]) + return None + + def __call__(self, prompt: str, response: str, target: str) -> float: + """Compute math correctness reward.""" + # Parse expected + expected_answer = self._to_float(target) + + # Parse response + model_answer = self._extract_number(response) + + # Scoring + if expected_answer is None or model_answer is None: + return self.partial_credit # Partial credit for attempting + + if abs(expected_answer - model_answer) < self.tolerance: + return 1.0 # Correct answer + return 0.0 # Incorrect answer diff --git a/src/forge/data/rewards/thinking.py b/src/forge/data/rewards/thinking.py new file mode 100644 index 000000000..8c4eb6852 --- /dev/null +++ b/src/forge/data/rewards/thinking.py @@ -0,0 +1,19 @@ +from typing import Optional + +from forge.interfaces import Reward + + +class ThinkingReward(Reward): + """Reward class for evaluating use of tags in reasoning.""" + + def __init__(self, reward_value: float = 0.5): + self.reward_value = reward_value + + def __call__( + self, prompt: str, response: str, target: Optional[str] = None + ) -> float: + """Check if response contains ... 
tags.""" + resp = response.lower() + if "" in resp and "" in resp: + return self.reward_value + return 0.0 diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index f19f379cb..b485fc791 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -7,10 +7,10 @@ from abc import ABC, abstractmethod from typing import Any -from monarch.actor import Actor, endpoint - from forge.types import Action, Message, Observation, State +from monarch.actor import Actor, endpoint + class Transform(ABC): """Abstract base class for observation transforms. @@ -150,3 +150,12 @@ def tokenize_messages( tuple[list[int], list[bool]]: The list of token ids and the list of masks. """ pass + + +class Reward(ABC): + """Abstract base class for reward models.""" + + @abstractmethod + def __call__(self, observation: Observation) -> float: + """Compute a reward for an observation.""" + pass diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/rl/test_math_reward.py new file mode 100644 index 000000000..a109492dd --- /dev/null +++ b/tests/unit_tests/rl/test_math_reward.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest import mock + +from forge.data.rewards.math import MathReward + + +class TestMathReward(unittest.TestCase): + def setUp(self): + """Set up test fixtures before each test method.""" + self.reward = MathReward() + self.custom_reward = MathReward(tolerance=1e-3, partial_credit=0.2) + + def test_init_default_values(self): + """Test MathReward initialization with default values.""" + reward = MathReward() + self.assertEqual(reward.tolerance, 1e-6) + self.assertEqual(reward.partial_credit, 0.1) + + def test_init_custom_values(self): + """Test MathReward initialization with custom values.""" + reward = MathReward(tolerance=1e-3, partial_credit=0.2) + self.assertEqual(reward.tolerance, 1e-3) + self.assertEqual(reward.partial_credit, 0.2) + + def test_to_float_valid_numbers(self): + """Test _to_float with valid numeric strings.""" + self.assertEqual(self.reward._to_float("42"), 42.0) + self.assertEqual(self.reward._to_float("3.14"), 3.14) + self.assertEqual(self.reward._to_float("-5.5"), -5.5) + self.assertEqual(self.reward._to_float("0"), 0.0) + self.assertEqual(self.reward._to_float(" 123.45 "), 123.45) + + def test_to_float_invalid_inputs(self): + """Test _to_float with invalid inputs.""" + self.assertIsNone(self.reward._to_float("abc")) + self.assertIsNone(self.reward._to_float("")) + self.assertIsNone(self.reward._to_float("12.34.56")) + self.assertIsNone(self.reward._to_float("not a number")) + self.assertIsNone(self.reward._to_float(None)) + + def test_to_float_edge_cases(self): + """Test _to_float with edge cases.""" + self.assertEqual(self.reward._to_float("1e6"), 1000000.0) + self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015) + self.assertEqual(self.reward._to_float("inf"), float("inf")) + self.assertEqual(self.reward._to_float("-inf"), float("-inf")) + + def test_extract_number_gsm8k_format(self): + """Test _extract_number with GSM8K style format.""" + self.assertEqual(self.reward._extract_number("#### 42"), 42.0) + self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14) + self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45) + + def test_extract_number_answer_patterns(self): + """Test _extract_number with 
various answer patterns.""" + self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0) + self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0) + self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5) + + def test_extract_number_equals_pattern(self): + """Test _extract_number with equals sign patterns.""" + self.assertEqual(self.reward._extract_number("x = 42."), 42.0) + self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5) + + def test_extract_number_end_of_text(self): + """Test _extract_number with numbers at end of text.""" + self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0) + self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5) + + def test_extract_number_fallback_pattern(self): + """Test _extract_number with fallback pattern (any number).""" + self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0) + self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14) + self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5) + + def test_extract_number_multiple_matches(self): + """Test _extract_number returns the last match when multiple numbers exist.""" + # Should return the last match from the pattern + self.assertEqual( + self.reward._extract_number("First 10, then 20, finally 30"), 30.0 + ) + self.assertEqual( + self.reward._extract_number("#### 5 but actually #### 10"), 10.0 + ) + + def test_extract_number_no_match(self): + """Test _extract_number when no numbers are found.""" + self.assertIsNone(self.reward._extract_number("No numbers here")) + self.assertIsNone(self.reward._extract_number("")) + self.assertIsNone(self.reward._extract_number("Just text")) + + def test_extract_number_case_insensitive(self): + """Test _extract_number is case insensitive.""" + self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0) + self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14) + self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0) + + def test_call_correct_answer(self): + """Test __call__ with correct answers.""" + self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0) + self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0) + self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0) + + def test_call_within_tolerance(self): + """Test __call__ with answers within tolerance.""" + # Default tolerance is 1e-6 + self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0) + self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0) + + # Custom tolerance + self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0) + self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0) + + def test_call_outside_tolerance(self): + """Test __call__ with answers outside tolerance.""" + self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) + self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) + self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0) + + def test_call_invalid_target(self): + """Test __call__ with invalid target values.""" + self.assertEqual( + self.reward("prompt", "42", "invalid"), self.reward.partial_credit + ) + self.assertEqual(self.reward("prompt", "42", ""), 
self.reward.partial_credit) + self.assertEqual( + self.reward("prompt", "42", "not a number"), self.reward.partial_credit + ) + + def test_call_invalid_response(self): + """Test __call__ with invalid response values.""" + self.assertEqual( + self.reward("prompt", "no number", "42"), self.reward.partial_credit + ) + self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit) + self.assertEqual( + self.reward("prompt", "just text", "42"), self.reward.partial_credit + ) + + def test_call_both_invalid(self): + """Test __call__ with both invalid target and response.""" + self.assertEqual( + self.reward("prompt", "no number", "invalid"), self.reward.partial_credit + ) + self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit) + + def test_call_custom_partial_credit(self): + """Test __call__ uses custom partial credit value.""" + self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2) + self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2) + + def test_call_zero_values(self): + """Test __call__ with zero values.""" + self.assertEqual(self.reward("prompt", "0", "0"), 1.0) + self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0) + + def test_call_negative_values(self): + """Test __call__ with negative values.""" + self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) + self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0) + self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0) + + def test_call_large_numbers(self): + """Test __call__ with large numbers.""" + self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0) + self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) + self.assertEqual(self.reward("prompt", "1000001", "1000000"), 0.0) + + def test_call_small_numbers(self): + """Test __call__ with very small numbers.""" + self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0) + self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0) + + def test_call_complex_response_text(self): + """Test __call__ with complex response text containing multiple elements.""" + response = """ + Let me solve this step by step: + First, I calculate 2 + 3 = 5 + Then, I multiply by 4: 5 * 4 = 20 + Finally, I subtract 8: 20 - 8 = 12 + #### 12 + """ + self.assertEqual(self.reward("prompt", response, "12"), 1.0) + + def test_call_with_units_and_formatting(self): + """Test __call__ with responses containing units and formatting.""" + self.assertEqual(self.reward("prompt", "The cost is $42.50", "42.5"), 1.0) + self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0) + self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py new file mode 100644 index 000000000..d4af16969 --- /dev/null +++ b/tests/unit_tests/rl/test_thinking_reward.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
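For reference, a usage sketch of MathReward consistent with the tests above (the prompts and numbers are illustrative; the import path is the one added by this patch):

from forge.data.rewards.math import MathReward

reward = MathReward(tolerance=1e-6, partial_credit=0.1)

# Exact numeric match against the target scores 1.0.
assert reward("What is 7 + 5?", "The answer is 12", "12") == 1.0
# A wrong number scores 0.0.
assert reward("What is 7 + 5?", "The answer is 13", "12") == 0.0
# An unparseable response (or target) falls back to partial credit.
assert reward("What is 7 + 5?", "I am not sure", "12") == 0.1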
+ +import unittest + +from forge.data.rewards.thinking import ThinkingReward + + +class TestThinkingReward(unittest.TestCase): + def setUp(self): + """Set up test fixtures before each test method.""" + self.reward = ThinkingReward() + self.custom_reward = ThinkingReward(reward_value=0.8) + + def test_init_default_values(self): + """Test ThinkingReward initialization with default values.""" + reward = ThinkingReward() + self.assertEqual(reward.reward_value, 0.5) + + def test_init_custom_values(self): + """Test ThinkingReward initialization with custom values.""" + reward = ThinkingReward(reward_value=0.8) + self.assertEqual(reward.reward_value, 0.8) + + def test_call_with_both_tags(self): + """Test __call__ with response containing both and tags.""" + response = "This is my reasoning" + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + result = self.custom_reward("prompt", response) + self.assertEqual(result, 0.8) + + def test_call_with_both_tags_complex_content(self): + """Test __call__ with complex content between thinking tags.""" + response = """ + Let me solve this problem step by step. + + First, I need to understand what the question is asking. + Then I'll work through the calculation: + 2 + 2 = 4 + So the answer should be 4. + + The answer is 4. + """ + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + def test_call_with_only_opening_tag(self): + """Test __call__ with response containing only tag.""" + response = "This is incomplete reasoning" + result = self.reward("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_with_only_closing_tag(self): + """Test __call__ with response containing only tag.""" + response = "This is incomplete reasoning" + result = self.reward("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_with_no_tags(self): + """Test __call__ with response containing no thinking tags.""" + response = "This is just a regular response without any thinking tags." + result = self.reward("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_case_insensitive(self): + """Test __call__ is case insensitive for thinking tags.""" + # Mixed case tags should work + response = "This is my reasoning" + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + response = "This is my reasoning" + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + response = "This is my reasoning" + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + def test_call_multiple_thinking_blocks(self): + """Test __call__ with multiple thinking blocks.""" + response = """ + First thought + Some text in between. 
+ Second thought + """ + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + def test_call_nested_tags(self): + """Test __call__ with nested or malformed tags.""" + # Nested tags - should still work as long as both tags exist + response = "Outer inner thought" + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + def test_call_empty_thinking_block(self): + """Test __call__ with empty thinking block.""" + response = "" + result = self.reward("prompt", response) + self.assertEqual(result, 0.5) + + def test_call_empty_response(self): + """Test __call__ with empty response.""" + result = self.reward("prompt", "") + self.assertEqual(result, 0.0) + + def test_call_tags_with_extra_whitespace(self): + """Test __call__ with thinking tags containing extra whitespace.""" + response = "< think >This has spaces< /think >" + result = self.reward("prompt", response) + self.assertEqual(result, 0.0) # Should not match due to spaces in tags + + def test_call_with_target_parameter(self): + """Test __call__ with target parameter (should be ignored).""" + response = "This is my reasoning" + result = self.reward("prompt", response, target="some target") + self.assertEqual(result, 0.5) + + result = self.reward("prompt", "no tags", target="some target") + self.assertEqual(result, 0.0) + + def test_call_zero_reward_value(self): + """Test __call__ with zero reward value.""" + zero_reward = ThinkingReward(reward_value=0.0) + response = "This is my reasoning" + result = zero_reward("prompt", response) + self.assertEqual(result, 0.0) + + def test_call_negative_reward_value(self): + """Test __call__ with negative reward value.""" + negative_reward = ThinkingReward(reward_value=-0.5) + response = "This is my reasoning" + result = negative_reward("prompt", response) + self.assertEqual(result, -0.5) + + +if __name__ == "__main__": + unittest.main() From 0681d7ad2324a005d8cb94569241051e945ab360 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:15:21 -0500 Subject: [PATCH 06/15] [Service] Refactors service to align with EX545081 (#65) * initial commit for replica * clean up * phase out service for service v2 * remove v2 * remove v2 from spawn * more minor cleanups * remove comment * remove comment * simplify and unify replica initialization * address comments * address comments * add capacity semaphore * f-strings * remove redundant health set --------- Co-authored-by: Allen Wang --- src/forge/controller/__init__.py | 2 - src/forge/controller/recoverable_mesh.py | 289 ------------ src/forge/controller/replica.py | 512 ++++++++++++++++++++ src/forge/controller/service.py | 574 +++++------------------ tests/test_service.py | 80 +++- 5 files changed, 711 insertions(+), 746 deletions(-) delete mode 100644 src/forge/controller/recoverable_mesh.py create mode 100644 src/forge/controller/replica.py diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index a191d931a..a800eb14c 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. 
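For context, the service spawn path exercised by the earlier policy patches in this series looks roughly like the sketch below; the ServiceConfig fields shown match patch 03 and may change as part of this refactor:

from forge.actors.policy import Policy
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service


async def spawn_policy_service(policy_config):
    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
    service = await spawn_service(service_config, Policy, config=policy_config)
    return service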
from .actor import ForgeActor from .proc_mesh import get_proc_mesh, spawn_actors -from .recoverable_mesh import RecoverableProcMesh from .service import Service, ServiceConfig from .spawn import spawn_service @@ -16,5 +15,4 @@ "spawn_actors", "get_proc_mesh", "ForgeActor", - "RecoverableProcMesh", ] diff --git a/src/forge/controller/recoverable_mesh.py b/src/forge/controller/recoverable_mesh.py deleted file mode 100644 index d352eab17..000000000 --- a/src/forge/controller/recoverable_mesh.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Recoverable Process Mesh - -This module provides a fault-tolerant wrapper around ProcMesh that automatically -recovers from crashes and failures. The RecoverableProcMesh class maintains the -same API as ProcMesh while adding automatic recovery capabilities. - -Key Features: -- **Automatic Recovery**: Detects mesh failures and automatically respawns processes -- **State Management**: Tracks mesh health and recovery status -- **Graceful Degradation**: Handles failures without losing the entire service -- **Context Management**: Supports async context manager for resource cleanup -- **Actor Respawning**: Automatically respawns actors after mesh recovery - -Example: - Basic usage with automatic recovery: - - >>> mesh = RecoverableProcMesh(num_gpus=2) - >>> - >>> async def spawn_actor(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass, *args) - ... return actor - >>> - >>> await mesh.spawn(spawn_actor) - >>> # Mesh will automatically recover if it fails - - Context manager usage: - - >>> async with RecoverableProcMesh(num_gpus=1) as mesh: - ... await mesh.spawn(spawn_actor) - ... # Mesh automatically cleaned up on exit -""" - -import asyncio -import logging -from enum import Enum -from typing import Any, Callable, Coroutine, Optional, TypeVar - -from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice -from monarch._src.actor.actor_mesh import Actor -from monarch._src.actor.proc_mesh import ProcMesh -from monarch._src.actor.shape import MeshTrait - -from forge.controller.proc_mesh import get_proc_mesh -from forge.types import ProcessConfig - -T = TypeVar("T", bound=Actor) -logger: logging.Logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class MeshState(Enum): - """ - Enumeration of possible mesh states for tracking recovery status. - - States: - HEALTHY: Mesh is operational and ready to handle requests - RECOVERING: Mesh is in the process of recovering from a failure - UNHEALTHY: Mesh has failed and needs recovery - STOPPED: Mesh has been explicitly stopped and cannot be used - """ - - HEALTHY = 0 - RECOVERING = 1 - UNHEALTHY = 2 - STOPPED = 3 - - -class RecoverableProcMesh(MeshTrait): - """ - A fault-tolerant wrapper around ProcMesh with automatic crash recovery. - - This class provides the same API as ProcMesh while adding robust failure detection - and automatic recovery capabilities. When the underlying mesh crashes or becomes - unresponsive, it automatically creates a new mesh and respawns all actors. - - The RecoverableProcMesh maintains state tracking to ensure proper recovery sequencing - and prevents resource leaks during failure scenarios. It's designed for long-running - services that need high availability. 
- - Args: - proc_config: ProcessConfig containing mesh configuration including num_procs - - Attributes: - num_procs: Number of processes allocated to this mesh - state: Current state of the mesh (HEALTHY, RECOVERING, UNHEALTHY, STOPPED) - healthy: True if the mesh is operational and ready for requests - failed: True if the mesh has failed and needs recovery - - Example: - Basic usage with automatic recovery: - - >>> proc_config = ProcessConfig(num_procs=2, scheduler="local") - >>> mesh = RecoverableProcMesh(proc_config) - >>> - >>> async def setup_actor(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass) - ... await actor.initialize.call() - >>> - >>> await mesh.spawn(setup_actor) - >>> # If mesh fails, it will automatically recover and re-run setup_actor - - Context manager for automatic cleanup: - - >>> proc_config = ProcessConfig(num_procs=1) - >>> async with RecoverableProcMesh(proc_config) as mesh: - ... await mesh.spawn(setup_actor) - ... # Use mesh for operations - ... # Mesh automatically stopped and cleaned up on exit - - Manual state checking: - - >>> if mesh.healthy: - ... # Safe to use mesh - ... pass - >>> elif mesh.failed: - ... # Mesh needs recovery - ... await mesh.spawn(setup_actor) # Triggers recovery - """ - - def __init__( - self, - proc_config: ProcessConfig, - ) -> None: - self._proc_config: ProcessConfig = proc_config - self.num_procs = proc_config.num_procs - self._proc_mesh: Optional[ProcMesh] = None - self._recovery_task: Optional[asyncio.Task[None]] = None - self.state: MeshState = MeshState.UNHEALTHY - - async def spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - """ - Spawn actors on the mesh with automatic recovery. - - This method ensures the mesh is healthy before spawning actors. If the mesh - has failed, it automatically triggers recovery and then executes the spawn hook. - The hook function receives the underlying ProcMesh and should handle actor - creation and initialization. - - Args: - hook: Async function that receives a ProcMesh and spawns/initializes actors - - Example: - >>> async def setup_actors(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass) - ... 
await actor.setup.call() - >>> - >>> await mesh.spawn(setup_actors) - """ - await self._background_spawn(hook) - - def trigger_spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - self._background_spawn(hook) - - def _background_spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> asyncio.Task[None]: - if self.state == MeshState.STOPPED: - logger.warning("ProcMesh was already stopped when trying to spawn") - - self.state = MeshState.RECOVERING - self._recovery_task = asyncio.create_task(self._recover(hook)) - - return self._recovery_task - - def gpus(self) -> int: - return self.num_procs - - async def _recover( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - self.state = MeshState.RECOVERING - - old_proc_mesh = self._proc_mesh - self._proc_mesh = None - - if old_proc_mesh is not None: - try: - await old_proc_mesh.stop() - except Exception as e: - logger.warning(f"Error stopping old ProcMesh: {e}") - - try: - self._proc_mesh = await get_proc_mesh(process_config=self._proc_config) - if self._proc_mesh is not None: - await hook(self._proc_mesh) - self.state = MeshState.HEALTHY - - except Exception as e: - logger.exception(f"Recovery attempt failed: {e}") - self.state = MeshState.UNHEALTHY - - @property - def healthy(self) -> bool: - return self.state == MeshState.HEALTHY - - @property - def failed(self) -> bool: - return self.state == MeshState.UNHEALTHY - - async def stop(self) -> None: - """ - Stop the mesh and clean up all resources. - - Gracefully shuts down the underlying ProcMesh and marks this recoverable - mesh as stopped. Once stopped, the mesh cannot be used for further operations. - - This method is idempotent - calling it multiple times is safe. - - Example: - >>> await mesh.stop() - >>> # Mesh is now stopped and cannot be used - """ - logger.info("Stopping RecoverableProcMesh") - if self.state == MeshState.STOPPED: - logger.info("RecoverableProcMesh was already stopped") - return - try: - if self._proc_mesh is not None: - await self._proc_mesh.stop() - except RuntimeError as e: - logger.warning("RecoverableProcMesh could not be stopped: %s", e) - - self.state = MeshState.STOPPED - - async def __aenter__(self) -> "RecoverableProcMesh": - """Enter the async context manager.""" - if self.state == MeshState.STOPPED: - raise RuntimeError("RecoverableProcMesh has already been stopped") - return self - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - """Exit the async context manager.""" - # In case there are multiple nested "async with" statements, we only - # want it to close once. - if self.state != MeshState.STOPPED: - await self.stop() - - def mark_failed(self): - """ - Mark the mesh as failed, triggering recovery on next spawn. - - This method is typically called when an operation on the mesh fails - or when external monitoring detects that the mesh is unresponsive. - The next call to spawn() will trigger automatic recovery. - - Example: - >>> try: - ... # Some operation that might fail - ... await actor.some_method.call() - >>> except Exception: - ... 
mesh.mark_failed() # Mark for recovery - """ - self.state = MeshState.UNHEALTHY - - @property - def _shape(self) -> Shape: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._shape - - @property - def _ndslice(self) -> Slice: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._ndslice - - @property - def _labels(self) -> list[str]: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._labels - - def _new_with_shape(self, shape: Shape) -> "RecoverableProcMesh": - raise NotImplementedError( - "RecoverableProcMesh does not support _new_with_shape" - ) diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py new file mode 100644 index 000000000..7403bcb31 --- /dev/null +++ b/src/forge/controller/replica.py @@ -0,0 +1,512 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Replica for distributed actor service.""" + +import asyncio +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +from monarch.actor import Actor, ActorError, ProcMesh + +from forge.controller import get_proc_mesh +from forge.types import ProcessConfig + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class ReplicaState(Enum): + HEALTHY = "HEALTHY" + RECOVERING = "RECOVERING" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + UNINITIALIZED = "UNINITIALIZED" + + +@dataclass +class ReplicaMetrics: + """Simple metrics tracking for a replica.""" + + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + request_times: deque = field(default_factory=lambda: deque(maxlen=100)) + request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) + + def add_request_start(self, timestamp: float): + """Records when a request starts processing.""" + self.request_times.append(timestamp) + self.total_requests += 1 + + def add_request_completion(self, start_time: float, success: bool): + """Records when a request completes.""" + latency = time.time() - start_time + self.request_latencies.append(latency) + if success: + self.successful_requests += 1 + else: + self.failed_requests += 1 + + def get_request_rate(self, window_seconds: float = 60.0) -> float: + """Gets requests per second over the last window_seconds.""" + assert window_seconds > 0, "Window must be positive" + now = time.time() + cutoff = now - window_seconds + recent_requests = [t for t in self.request_times if t >= cutoff] + return len(recent_requests) / window_seconds + + def get_avg_latency(self, window_requests: int = 50) -> float: + """Gets average latency over the last N requests.""" + if not self.request_latencies: + return 0.0 + recent_latencies = list(self.request_latencies)[-window_requests:] + return sum(recent_latencies) / len(recent_latencies) + + +@dataclass +class ServiceRequest: + """Representation of a request to the service. + + A service request will typically be a call to an actor endpoint. + - The endpoint call is represented by function str/args/kwargs, + - The session_id is used for stateful routing, and + - The future is used to return the result of the call. 
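A standalone sketch of how the ReplicaMetrics bookkeeping above is meant to be used (the values are illustrative; the import path is the new module added by this patch):

import time

from forge.controller.replica import ReplicaMetrics

metrics = ReplicaMetrics()
start = time.time()
metrics.add_request_start(start)
# ... the request is processed ...
metrics.add_request_completion(start, success=True)

assert metrics.total_requests == 1 and metrics.successful_requests == 1
rate = metrics.get_request_rate(window_seconds=60.0)  # requests per second
latency = metrics.get_avg_latency()  # mean latency of recent requests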
+ + """ + + session_id: Optional[str] + function: str + args: tuple + kwargs: dict + future: asyncio.Future + + +@dataclass +class Replica: + """ + A distributed replica that serves as the fundamental unit of work within a service. + + Handles process lifecycle, async request queuing and fault recovery. + Each replica runs independently and can be deployed across multiple hosts via Monarch + + """ + + idx: int + + # Configuration for the underlying ProcMesh (scheduler, hosts, GPUs) + proc_config: ProcessConfig + actor_def: type[Actor] + actor_args: tuple + actor_kwargs: dict + + # The proc_mesh and actor_mesh that this replica is running + proc_mesh: Optional[ProcMesh] = None + actor: Optional[Actor] = None + + # Async queue for incoming requests + request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue) + # Number of currently processing requests + active_requests: int = 0 + # Maximum number of simultaneous requests + max_concurrent_requests: int = 10 + # Semaphore to control request capacity + _capacity_semaphore: asyncio.Semaphore = field(init=False) + # Whether the processing loop is currently running + _running: bool = False + # How often to check for new requests when idle + _run_poll_rate_s: float = 1.0 + # Current replica health state + state: ReplicaState = ReplicaState.UNINITIALIZED + # Whether to auto-unwrap ValueMesh to first rank + return_first_rank_result: bool = False + + # Recovery-related state + _recovery_task: Optional[asyncio.Task] = None + + # Run task is the replica's event loop + _run_task: Optional[asyncio.Task] = None + + # Metrics tracking + metrics: ReplicaMetrics = field(default_factory=ReplicaMetrics) + + def __post_init__(self): + # This semaphore is used to enforce max_concurrent_requests + # Once it is acquired max_concurrent_requests times, future + # requests are blocked until standing requests complete. + self._capacity_semaphore = asyncio.Semaphore(self.max_concurrent_requests) + + # Initialization related functionalities + + async def initialize(self): + """ + Initializes the replica completely from proc_mesh creation to ready state. 
+ + This method handles the complete replica initialization process: + - Creates the proc_mesh + - Spawns the actor + - Configures the actor + - Transitions to healthy state + - Starts the processing loop + """ + assert self.proc_mesh is None, "Proc mesh should not be set yet" + try: + # Create proc_mesh + await self.create_proc_mesh() + + # Ensure we have a healthy proc_mesh + if not self.proc_mesh: + raise RuntimeError( + f"Replica {self.idx}: proc_mesh is None after creation" + ) + + # Spawn the actor + await self.spawn_actor( + actor_def=self.actor_def, + *self.actor_args, + **self.actor_kwargs, + ) + # Transition to healthy state and start processing + self.state = ReplicaState.HEALTHY + self.start_processing() + + logger.debug(f"Replica {self.idx} initialization complete") + + except Exception as e: + logger.error(f"Failed to initialize replica {self.idx}: {e}") + self.state = ReplicaState.UNHEALTHY + raise + + async def recover(self): + """Recovers the replica by recreating the proc_mesh and respawning actors.""" + if self._recovery_task and not self._recovery_task.done(): + # Recovery already in progress, wait for it + await self._recovery_task + return + + async def _do_recovery(): + old_proc_mesh = self.proc_mesh + self.proc_mesh = None + self.actor = None + + # Stop old proc_mesh if it exists + if old_proc_mesh is not None: + try: + await old_proc_mesh.stop() + logger.debug("Old proc_mesh stopped for replica %d", self.idx) + except Exception as e: + logger.warning( + "Error stopping old proc_mesh for replica %d: %s", self.idx, e + ) + + try: + logger.debug("Creating new proc_mesh for replica %d", self.idx) + await self.initialize() + logger.debug("Recovery completed successfully for replica %d", self.idx) + except Exception as e: + logger.error("Recovery failed for replica %d: %s", self.idx, e) + self.state = ReplicaState.UNHEALTHY + raise + + logger.debug("Starting recovery for replica %d", self.idx) + self._recovery_task = asyncio.create_task(_do_recovery()) + await self._recovery_task + + async def create_proc_mesh(self): + """Creates the proc_mesh using the stored proc_config.""" + # TODO - for policy replica, we would override this method to + # include multiple proc_meshes + if self.proc_mesh is not None: + logger.warning("Proc mesh already initialized for replica %d", self.idx) + return + + logger.debug("Creating proc_mesh for replica %d", self.idx) + try: + self.proc_mesh = await get_proc_mesh(process_config=self.proc_config) + logger.debug("Proc mesh created successfully for replica %d", self.idx) + except Exception as e: + logger.error("Failed to create proc_mesh for replica %d: %s", self.idx, e) + self.state = ReplicaState.UNHEALTHY + raise + + async def spawn_actor(self, actor_def, *actor_args, **actor_kwargs): + """ + Spawn an actor on this replica's proc_mesh. + + This method handles the complete actor spawning process including + recovery if the proc_mesh has failed. 
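+
+        Rough usage sketch (illustrative; MyActor and its kwargs are
+        placeholders):
+
+            await replica.create_proc_mesh()
+            await replica.spawn_actor(MyActor, name="my_actor", some_arg=1)
+            # "name" is popped from the kwargs; everything else is forwarded
+            # to the actor, and its setup() endpoint is called if present.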
+ """ + if not self.proc_mesh: + raise RuntimeError( + f"Replica {self.idx}: proc_mesh is None after recovery attempt" + ) + + try: + # TODO - expand support so name can stick within kwargs + actor_name = actor_kwargs.pop("name", actor_def.__name__) + + # Spawn the actor + self.actor = await self.proc_mesh.spawn( + actor_name, + actor_def, + *actor_args, + **actor_kwargs, + ) + # Call actor setup if it exists + if setup_method := getattr(self.actor, "setup", None): + await setup_method.call() + + logger.debug("Actor spawned successfully on replica %d", self.idx) + + except Exception as e: + logger.error("Failed to spawn actor on replica %d: %s", self.idx, e) + self.mark_failed() + raise + + # Request handling / processing related functionality + + def start_processing(self): + """Start the replica's processing loop if not already running.""" + if self._run_task is None or self._run_task.done(): + self._run_task = asyncio.create_task(self.run()) + logger.debug("Started processing loop for replica %d", self.idx) + + async def enqueue_request(self, request: ServiceRequest): + """Enqueues a request for processing by this replica.""" + if self.stopped: + raise RuntimeError( + f"Replica {self.idx} is stopped and therefore will not accept requests." + ) + + # Accept requests in all other states - let the processing loop handle the rest + await self.request_queue.put(request) + + async def _process_single_request(self, request: ServiceRequest) -> bool: + """Processes a single request and returns success status. + + Returns: + bool: True if request succeeded, False if it failed + """ + start_time = time.time() + self.active_requests += 1 + + # Record request start for metrics + self.metrics.add_request_start(start_time) + + try: + # Get the actor and endpoint + actor = self.actor + endpoint_func = getattr(actor, request.function) + + # Execute the request + success = True + try: + result = await endpoint_func.call(*request.args, **request.kwargs) + # Unwrap ValueMesh if configured to return first rank result + if ( + self.return_first_rank_result + and hasattr(result, "_values") + and result._values + ): + result = result._values[0] + request.future.set_result(result) + except ActorError as e: + logger.warning("Got failure on replica %d. Error:\n%s", self.idx, e) + # The exception came from the actor. It itself is + # returned to be propagated through the services + # back to the caller. + request.future.set_result(e.exception) + + # TODO: we may want to conditionally mark the + # replica as failed here - i.e. where the actor itself + # can be healthy but the request failed. + self.mark_failed() + success = False + except Exception as e: + logger.debug( + "Got unexpected error on replica %d. Error:\n%s", self.idx, e + ) + self.mark_failed() + + # The exception was not from the actor - in this case + # we will signal back to the service (through set_exception) + # to retry on another healthy node. + request.future.set_exception(e) + success = False + + self.metrics.add_request_completion(start_time, success) + # Mark task as done + self.request_queue.task_done() + return success + + finally: + self.active_requests -= 1 + # Release the capacity semaphore to allow new requests + self._capacity_semaphore.release() + + async def run(self): + """Runs the main processing loop for the replica. + + Continuously processes requests from the queue while the replica is healthy. + Handles capacity management and graceful degradation on failures. 
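+
+        Rough control flow (sketch only):
+
+            await replica.initialize()          # ends in start_processing()
+            await replica.enqueue_request(request)
+            # run() pops the request, acquires the capacity semaphore, and
+            # processes it in a background task via _process_single_request().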
+ """ + self._running = True + + try: + while self.healthy: + try: + # Wait for a request with timeout to check health periodically + request = await asyncio.wait_for( + self.request_queue.get(), timeout=self._run_poll_rate_s + ) + + # Acquire capacity semaphore - this blocks until capacity is available + await self._capacity_semaphore.acquire() + + # Process the request (semaphore will be released in _process_single_request) + asyncio.create_task(self._process_single_request(request)) + + except asyncio.TimeoutError: + # No requests, just continue checking for new ones + continue + + except Exception as e: + logger.error( + "Error in replica %d processing loop: %s", + self.idx, + e, + ) + self.state = ReplicaState.UNHEALTHY + break + + finally: + self._running = False + logger.debug("Replica %d stopped processing", self.idx) + + # Replica state management + + @property + def healthy(self) -> bool: + return self.state == ReplicaState.HEALTHY + + @property + def uninitialized(self) -> bool: + return self.state == ReplicaState.UNINITIALIZED + + @property + def recovering(self) -> bool: + return self.state == ReplicaState.RECOVERING + + @property + def unhealthy(self) -> bool: + return self.state == ReplicaState.UNHEALTHY + + @property + def stopped(self) -> bool: + return self.state == ReplicaState.STOPPED + + @property + def failed(self) -> bool: + """Check if the replica has failed and needs recovery.""" + return self.state in (ReplicaState.RECOVERING, ReplicaState.UNHEALTHY) + + def mark_failed(self): + """Mark the replica as failed, triggering recovery.""" + logger.debug("Marking replica %d as failed", self.idx) + self.state = ReplicaState.RECOVERING + + async def stop(self): + """ + Stops the replica gracefully. + + Transitions to STOPPED state, stops the processing loop, and cleans up. + Fails any remaining requests in the queue. 
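+
+        Illustrative shutdown sequence (sketch only):
+
+            await replica.enqueue_request(request)   # still queued
+            await replica.stop()                     # replica.stopped is now True
+            await request.future                     # raises RuntimeError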
+ """ + logger.debug("Stopping replica %d", self.idx) + + # Transition to stopped state to signal the run loop to exit + self.state = ReplicaState.STOPPED + + if self._run_task and not self._run_task.done(): + self._run_task.cancel() + try: + await asyncio.wait_for( + self._run_task, timeout=2 * self._run_poll_rate_s + ) + except (asyncio.CancelledError, asyncio.TimeoutError): + # Expected - task was cancelled or timed out + pass + except Exception as e: + logger.warning("Unexpected error while stopping run task: %s", e) + + # Fail any remaining requests in the queue + failed_requests = [] + while not self.request_queue.empty(): + try: + request = self.request_queue.get_nowait() + failed_requests.append(request) + self.request_queue.task_done() + except asyncio.QueueEmpty: + # catching in case the queue became empty + # between check and get + break + + # Fail all the collected requests + for request in failed_requests: + if not request.future.done(): + request.future.set_exception( + RuntimeError(f"Replica {self.idx} is stopping") + ) + + logger.debug( + "Replica %d stopped, failed %d remaining requests", + self.idx, + len(failed_requests), + ) + + # Stop the proc_mesh + if self.proc_mesh: + try: + await self.proc_mesh.stop() + except Exception as e: + logger.warning( + "Error stopping proc_mesh for replica %d: %s", self.idx, e + ) + + # Metric-related getters + + @property + def current_load(self) -> int: + """Get current load (active requests + queue depth)""" + return self.active_requests + self.request_queue.qsize() + + def qsize(self) -> int: + """Get current queue size""" + return self.request_queue.qsize() + + @property + def capacity_utilization(self) -> float: + """Get current capacity utilization (0.0 to 1.0)""" + if self.max_concurrent_requests <= 0: + return 0.0 + return self.active_requests / self.max_concurrent_requests + + def can_accept_request(self) -> bool: + """Check if replica can accept a new request""" + return ( + self.state == ReplicaState.HEALTHY + and self.active_requests < self.max_concurrent_requests + ) + + def __repr__(self) -> str: + return ( + f"Replica(idx={self.idx}, state={self.state.value}, " + f"active={self.active_requests}/{self.max_concurrent_requests}, " + f"queue={self.request_queue.qsize()})" + ) diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py index 13c58db36..6be098639 100644 --- a/src/forge/controller/service.py +++ b/src/forge/controller/service.py @@ -32,95 +32,24 @@ ... result = await service.my_endpoint(arg1, arg2) """ - import asyncio import contextvars import logging import pprint -import time import uuid -from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, Dict, List, Optional +from typing import Dict, List from monarch._src.actor.endpoint import EndpointProperty -from monarch.actor import ActorError, ProcMesh -from forge.controller import RecoverableProcMesh +from forge.controller.replica import Replica, ReplicaMetrics, ServiceRequest from forge.types import ServiceConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -# TODO - tie this into metric logger when it exists -@dataclass -class ReplicaMetrics: - """ - Metrics collection for a single replica instance. - - Tracks request counts, timing metrics, current state, and session assignments - for performance monitoring and autoscaling decisions. 
- - Attributes: - replica_idx: Unique identifier for this replica - total_requests: Total number of requests processed - successful_requests: Number of successfully completed requests - failed_requests: Number of failed requests - request_times: Sliding window of request start timestamps - request_latencies: Sliding window of request completion latencies - active_requests: Currently processing requests - queue_depth: Number of pending requests in queue - assigned_sessions: Number of sessions assigned to this replica - """ - - replica_idx: int - # Request metrics - total_requests: int = 0 - successful_requests: int = 0 - failed_requests: int = 0 - # Timing metrics (sliding window) - request_times: deque = field(default_factory=lambda: deque(maxlen=100)) - request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) - # Current state - active_requests: int = 0 - queue_depth: int = 0 - # Session metrics - assigned_sessions: int = 0 - - def add_request_start(self, timestamp: float): - """Record when a request starts processing.""" - self.request_times.append(timestamp) - self.total_requests += 1 - - def add_request_completion(self, start_time: float, success: bool): - """Record when a request completes.""" - latency = time.time() - start_time - self.request_latencies.append(latency) - if success: - self.successful_requests += 1 - else: - self.failed_requests += 1 - - def get_request_rate(self, window_seconds: float = 60.0) -> float: - """Get requests per second over the last window_seconds.""" - now = time.time() - cutoff = now - window_seconds - recent_requests = [t for t in self.request_times if t >= cutoff] - return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 - - def get_avg_latency(self, window_requests: int = 50) -> float: - """Get average latency over the last N requests.""" - if not self.request_latencies: - return 0.0 - recent_latencies = list(self.request_latencies)[-window_requests:] - return sum(recent_latencies) / len(recent_latencies) - - def get_capacity_utilization(self, max_concurrent: int) -> float: - """Get current capacity utilization (0.0 to 1.0).""" - return self.active_requests / max_concurrent if max_concurrent > 0 else 0.0 - - +# TODO - tie this into metrics logger when it exists. 
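+# A possible consumer of the aggregated metrics (sketch only; it uses the
+# summary keys produced by get_metrics_summary() further below):
+#
+#     summary = service.get_metrics_summary()
+#     logger.info(
+#         "healthy=%d/%d rate=%.2f req/s avg_queue_depth=%.2f",
+#         summary["healthy_replicas"],
+#         summary["total_replicas"],
+#         summary["total_request_rate"],
+#         summary["avg_queue_depth"],
+#     )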
@dataclass class ServiceMetrics: """ @@ -153,79 +82,67 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float: for metrics in self.replica_metrics.values() ) - def get_avg_queue_depth(self) -> float: + def get_avg_queue_depth(self, replicas: List) -> float: """Get average queue depth across all healthy replicas.""" - healthy_metrics = [ - m - for m in self.replica_metrics.values() - if m.replica_idx < self.healthy_replicas - ] - if not healthy_metrics: + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: return 0.0 - return sum(m.queue_depth for m in healthy_metrics) / len(healthy_metrics) + total_queue_depth = sum(r.qsize() for r in healthy_replicas) + return total_queue_depth / len(healthy_replicas) def get_avg_capacity_utilization(self, replicas: List) -> float: """Get average capacity utilization across all healthy replicas.""" - healthy_replicas = [r for r in replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: return 0.0 - - utilizations = [] - for replica in healthy_replicas: - if replica.idx in self.replica_metrics: - metrics = self.replica_metrics[replica.idx] - utilization = metrics.get_capacity_utilization( - replica.max_concurrent_requests - ) - utilizations.append(utilization) - - return sum(utilizations) / len(utilizations) if utilizations else 0.0 + total_utilization = sum(r.capacity_utilization for r in healthy_replicas) + return total_utilization / len(healthy_replicas) def get_sessions_per_replica(self) -> float: - """Get average sessions per healthy replica.""" - if self.healthy_replicas == 0: + """Get average sessions per replica.""" + if self.total_replicas == 0: return 0.0 - return self.total_sessions / self.healthy_replicas + return self.total_sessions / self.total_replicas -@dataclass -class Replica: - proc_mesh: RecoverableProcMesh - actor: Any - idx: int - request_queue: asyncio.Queue[dict] = field(default_factory=asyncio.Queue) - active_requests: int = 0 - max_concurrent_requests: int = 10 - _processor_running: bool = False - metadata: dict = field(default_factory=dict) +# Context variable for session state +_session_context = contextvars.ContextVar("session_context") @dataclass class Session: + """Simple session data holder.""" + session_id: str -# Global context variable for session state -# This is used to propagate session state across async tasks -_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar( - "session_context", default=None -) +class SessionContext: + """ + Async context manager for stateful service sessions with automatic lifecycle management. + Provides a convenient way to maintain stateful connections to replicas across multiple + requests. Sessions ensure that all requests within the context are routed to the same + replica, enabling stateful interactions while handling session lifecycle automatically. -class SessionContext: - """Context manager for service sessions using context variables.""" + Example: + + >>> async with service.session() as session: + ... # All calls within this block use the same replica + ... result1 = await service.my_endpoint(arg1) + ... 
result2 = await service.another_endpoint(result1) + + """ - def __init__(self, service: "Service", **session_kwargs): + def __init__(self, service: "Service"): self.service = service self.session_id: str | None = None - self.session_kwargs = session_kwargs self._token = None async def __aenter__(self): """Start a session and set context variables.""" self.session_id = await self.service.start_session() # Set context for this async task - context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs} + context_value = {"session_id": self.session_id} self._token = _session_context.set(context_value) return self @@ -299,56 +216,46 @@ def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs): # Initialize metrics collection self._metrics = ServiceMetrics() - - # Autoscaling state - self._last_scale_up_time = 0.0 - self._last_scale_down_time = 0.0 - self._low_utilization_start_time = None self._health_task = None self._shutdown_requested = False # Replica initialization queue - self._replicas_to_init = [] + self._replicas_to_recover = [] # For all endpoints within the actor_def, create an interface from it self._endpoints = [] for func_name in dir(actor_def): func = getattr(actor_def, func_name) if isinstance(func, EndpointProperty): - logger.debug("Registering endpoint %s", func_name) + logger.debug(f"Registering endpoint {func_name}") self._endpoints.append(func_name) # Dynamically add this endpoint method to the Service class self._add_endpoint_method(func_name) async def __initialize__(self): - logger.debug("Starting service up with %d replicas.", self._cfg.num_replicas) + """Initializes the service and starts the health loop.""" + logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.") replicas = [] num_replicas = self._cfg.num_replicas for i in range(num_replicas): - mesh = RecoverableProcMesh(proc_config=self._cfg.to_process_config()) replica = Replica( - proc_mesh=mesh, - actor=None, idx=len(self._replicas) + i, + proc_config=self._cfg.to_process_config(), max_concurrent_requests=self._cfg.replica_max_concurrent_requests, + return_first_rank_result=self._cfg.return_first_rank_result, + actor_def=self._actor_def, + actor_args=self._actor_args, + actor_kwargs=self._actor_kwargs, ) replicas.append(replica) - # Initializing should only happen in the health_loop - # and during the first initialization. - # If multiple parts of the code try to initialize replicas at - # the same time, it can cause nasty race conditions - # (e.g., double initialization, inconsistent state, or resource conflicts). - # By funneling all replica initialization through a single queue and the - # health loop, we ensure safe, serialized initialization. logger.debug( - "Queued %d replicas for initialization. Total replicas: %d", - num_replicas, - len(self._replicas), + f"Queued {num_replicas} replicas for initialization. 
Total replicas: {len(self._replicas)}" ) - self._replicas_to_init.extend(replicas) - await self._maybe_init_replicas() - self._replicas.extend(replicas) + + # Initialize all replicas in parallel + await asyncio.gather(*[r.initialize() for r in replicas]) + self._replicas = replicas # Start the health loop in the background self._health_task = asyncio.create_task( @@ -390,140 +297,40 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): """ # Check context variables for session state if no explicit sess_id if sess_id is None: - ctx = _session_context.get() + ctx = _session_context.get(None) if ctx: sess_id = ctx["session_id"] - routing_hints = ctx["kwargs"] - else: - routing_hints = {} - else: - routing_hints = {} - - replica = await self._get_replica(sess_id, **routing_hints) - - # Create a request object to queue - request = { - "sess_id": sess_id, - "function": function, - "args": args, - "kwargs": kwargs, - "future": asyncio.Future(), - } - # Queue the request - await replica.request_queue.put(request) - # Ensure the replica has a processor running - self._ensure_processor_running(replica) + replica = await self._get_replica(sess_id) + + # Create a ServiceRequest object to queue + request = ServiceRequest( + session_id=sess_id, + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: - return await request["future"] + return await request.future except Exception as e: # If the replica failed, try to retry once - if not replica.proc_mesh.healthy: + if not replica.healthy: logger.debug( - "Replica %d failed during request, retrying on healthy replica", + "Replica %d failed during request, retrying on healthy replica. 
Exception: %s", replica.idx, + e, ) return await self._retry_request_on_healthy_replica( sess_id, function, *args, **kwargs ) raise - def _ensure_processor_running(self, replica: Replica): - """Ensures a persistent processor is running for this replica.""" - if not replica._processor_running: - replica._processor_running = True - asyncio.create_task(self._persistent_processor(replica)) - - async def _persistent_processor(self, replica: Replica): - """Persistent processor that continuously handles requests for a replica.""" - try: - while replica.proc_mesh.healthy: - try: - # Wait for a request with timeout to check health periodically - request = await asyncio.wait_for( - replica.request_queue.get(), timeout=1.0 - ) - - # Check if we have capacity - if replica.active_requests >= replica.max_concurrent_requests: - # Put the request back and wait - await replica.request_queue.put(request) - await asyncio.sleep(0.1) - continue - - # Process the request - asyncio.create_task(self._process_single_request(replica, request)) - - except asyncio.TimeoutError: - # No requests, continue to check health - continue - except Exception as e: - logger.error( - "Error in persistent processor for replica %d: %s", - replica.idx, - e, - ) - break - finally: - replica._processor_running = False - # Migrate any remaining requests to healthy replicas - await self._migrate_remaining_requests(replica) - - async def _process_single_request(self, replica: Replica, request: dict): - """Processes a single request.""" - start_time = time.time() - replica.active_requests += 1 - - # Get or create metrics for this replica - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.add_request_start(start_time) - replica_metrics.active_requests = replica.active_requests - - try: - # Get the actor and endpoint - actor = replica.actor - endpoint_func = getattr(actor, request["function"]) - - # Execute the request - success = True - try: - result = await endpoint_func.call(*request["args"], **request["kwargs"]) - if ( - self._cfg.return_first_rank_result - and hasattr(result, "_values") - and result._values - ): - result = result._values[0] - request["future"].set_result(result) - except ActorError as e: - logger.debug("Got failure on replica %d. Error:\n%s", replica.idx, e) - replica.proc_mesh.mark_failed() - # Unwrap the ActorError into its raw exception. - request["future"].set_result(e.exception) - success = False - except Exception as e: - logger.debug( - "Got unexpected error on replica %d. 
Error:\n%s", replica.idx, e - ) - replica.proc_mesh.mark_failed() - request["future"].set_result(e) - success = False - - # Record completion metrics - replica_metrics.add_request_completion(start_time, success) - - # Mark task as done - replica.request_queue.task_done() - - finally: - replica.active_requests -= 1 - replica_metrics.active_requests = replica.active_requests - async def _retry_request_on_healthy_replica( self, sess_id: str | None, function: str, *args, **kwargs ): @@ -558,13 +365,13 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): # Find healthy replicas healthy_replicas = [ - r for r in self._replicas if r.proc_mesh.healthy and r != failed_replica + r for r in self._replicas if r.healthy and r != failed_replica ] if not healthy_replicas: # No healthy replicas, fail all requests for request in migrated_requests: - request["future"].set_exception( + request.future.set_exception( RuntimeError("No healthy replicas available") ) return @@ -572,11 +379,10 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): # Distribute requests among healthy replicas for i, request in enumerate(migrated_requests): target_replica = healthy_replicas[i % len(healthy_replicas)] - await target_replica.request_queue.put(request) - self._ensure_processor_running(target_replica) + await target_replica.enqueue_request(request) # Update session mapping if needed - sess_id = request["sess_id"] + sess_id = request.session_id if ( sess_id in self._session_replica_map and self._session_replica_map[sess_id] == failed_replica.idx @@ -608,35 +414,20 @@ async def start_session(self) -> str: return sess_id - def session(self, **kwargs) -> SessionContext: + def session(self) -> SessionContext: """Returns a context manager for session-based calls.""" - return SessionContext(self, **kwargs) + return SessionContext(self) def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum( - 1 for r in self._replicas if r.proc_mesh.healthy - ) - - # Update queue depths for all replicas + self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + # Store direct references to replica metrics for aggregation + self._metrics.replica_metrics = {} for replica in self._replicas: - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.queue_depth = replica.request_queue.qsize() - replica_metrics.active_requests = replica.active_requests - - # Update session assignments per replica - session_counts = defaultdict(int) - for sess_id, replica_idx in self._session_replica_map.items(): - session_counts[replica_idx] += 1 - - for replica_idx, count in session_counts.items(): - if replica_idx in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica_idx].assigned_sessions = count + # Use the replica's own metrics directly + self._metrics.replica_metrics[replica.idx] = replica.metrics def get_metrics(self) -> ServiceMetrics: """ @@ -680,7 +471,7 @@ def get_metrics_summary(self) -> dict: "healthy_replicas": self._metrics.healthy_replicas, "total_replicas": self._metrics.total_replicas, "total_request_rate": self._metrics.get_total_request_rate(), - "avg_queue_depth": self._metrics.get_avg_queue_depth(), + "avg_queue_depth": 
self._metrics.get_avg_queue_depth(self._replicas), "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( self._replicas ), @@ -689,17 +480,26 @@ def get_metrics_summary(self) -> dict: "replicas": {}, } - for replica_idx, metrics in self._metrics.replica_metrics.items(): - summary["replicas"][replica_idx] = { + for replica in self._replicas: + metrics = replica.metrics + + # Count sessions assigned to this replica + assigned_sessions = sum( + 1 + for replica_idx in self._session_replica_map.values() + if replica_idx == replica.idx + ) + + summary["replicas"][replica.idx] = { "total_requests": metrics.total_requests, "successful_requests": metrics.successful_requests, "failed_requests": metrics.failed_requests, "request_rate": metrics.get_request_rate(), "avg_latency": metrics.get_avg_latency(), - "active_requests": metrics.active_requests, - "queue_depth": metrics.queue_depth, - "assigned_sessions": metrics.assigned_sessions, - "capacity_utilization": metrics.get_capacity_utilization(10), + "active_requests": replica.active_requests, # Get from replica + "queue_depth": replica.qsize(), + "assigned_sessions": assigned_sessions, # Calculate from session map + "capacity_utilization": replica.capacity_utilization, # Get from replica } return summary @@ -743,13 +543,13 @@ async def _health_loop(self, poll_rate_s: float): """ while not self._shutdown_requested: - # Process any replicas that need initialization - await self._maybe_init_replicas() + # Process any replicas that need recovery + await self._recover_replicas() # Check for failed replicas and recover them failed_replicas = [] for replica in self._replicas: - if replica.proc_mesh.failed: + if replica.failed: failed_replicas.append(replica) if any(failed_replicas): @@ -758,19 +558,13 @@ async def _health_loop(self, poll_rate_s: float): len(failed_replicas), pprint.pformat(failed_replicas), ) - self._replicas_to_init.extend(failed_replicas) + self._replicas_to_recover.extend(failed_replicas) await asyncio.sleep(poll_rate_s) - async def _custom_replica_routing( - self, sess_id: str | None, **kwargs - ) -> Optional[Replica]: - """Hook for custom routing logic. 
Override in subclasses to implement custom routing.""" - return None - def _get_next_replica(self) -> "Replica": """Get the next replica using round-robin selection.""" - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in self._replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for load balancing") @@ -780,25 +574,15 @@ def _get_next_replica(self) -> "Replica": def _get_least_loaded_replica(self) -> "Replica": """Get the replica with the lowest load.""" - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in self._replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for session assignment") - # Load = active_requests + queue_depth - def get_load(replica: "Replica") -> int: - return replica.active_requests + replica.request_queue.qsize() - - return min(healthy_replicas, key=get_load) - - async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": - """Get a replica for the given session ID, with optional custom routing hints.""" - # Try custom routing first if hints are provided - if kwargs: - custom_result = await self._custom_replica_routing(sess_id, **kwargs) - if custom_result is not None: - return custom_result + # Use the replica's current_load property + return min(healthy_replicas, key=lambda replica: replica.current_load) - # Default routing logic + async def _get_replica(self, sess_id: str | None) -> "Replica": + """Get a replica for the given session ID.""" if sess_id is None: # No session, use round-robin load balancing replica = self._get_next_replica() @@ -809,7 +593,7 @@ async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": replica_idx = self._session_replica_map[sess_id] # Find the replica with this index for replica in self._replicas: - if replica.idx == replica_idx and replica.proc_mesh.healthy: + if replica.idx == replica_idx and replica.healthy: return replica # If the replica is no longer healthy, remove from session map and reassign del self._session_replica_map[sess_id] @@ -838,151 +622,37 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all replicas using their stop method await asyncio.gather( - *[replica.proc_mesh.stop() for replica in self._replicas], + *[replica.stop() for replica in self._replicas], return_exceptions=True, ) - async def _maybe_init_replicas(self): - """Initializes replicas that are queued for initialization.""" - if not self._replicas_to_init: + async def _recover_replicas(self): + """Recovers unhealthy queued replicas.""" + if not self._replicas_to_recover: return - logger.debug("Init replicas: %s", pprint.pformat(self._replicas_to_init)) - - def _recover_hook( - replica: Replica, - ) -> Callable[[ProcMesh], Coroutine[Any, Any, None]]: - async def inner_hook(proc_mesh: ProcMesh) -> None: - if "name" in self._actor_kwargs: - actor_name = self._actor_kwargs.pop("name") - else: - actor_name = self._actor_def.__name__ - # TODO - expand support so name can stick within kwargs - actor = await proc_mesh.spawn( - actor_name, - self._actor_def, - *self._actor_args, - **self._actor_kwargs, - ) - replica.actor = actor - if hasattr(actor, "setup"): - await actor.setup.call() - - return inner_hook - - await asyncio.gather( - *[ - replica.proc_mesh.spawn(_recover_hook(replica)) - for replica in self._replicas_to_init - ] - ) - self._replicas_to_init.clear() - - async def 
_scale_up(self, num_replicas: int = 1): - """ - Scales up the service by adding new replicas. - - Creates new replica instances with their own process meshes and queues them - for initialization. The replicas will be initialized asynchronously by the - health loop to avoid blocking the scaling operation. - - Args: - num_replicas: Number of replicas to add (default: 1) - - Note: - Replicas are queued for initialization rather than initialized immediately - to prevent blocking during scaling operations. - """ - logger.debug("Scaling up with %d replicas.", num_replicas) - new_replicas = [] - for i in range(num_replicas): - mesh = RecoverableProcMesh( - self._cfg.procs_per_replica, - ) - replica = Replica( - proc_mesh=mesh, - actor=None, - idx=len(self._replicas) + i, - max_concurrent_requests=self._cfg.replica_max_concurrent_requests, - ) - new_replicas.append(replica) - - # Add to the initialization queue instead of initializing immediately - self._replicas_to_init.extend(new_replicas) - self._replicas.extend(new_replicas) logger.debug( - "Queued %d replicas for initialization. Total replicas: %d", - num_replicas, - len(self._replicas), + "Recovering replicas: %s", pprint.pformat(self._replicas_to_recover) ) - async def _scale_down_replicas(self, num_replicas: int = 1): - """ - Scales down the service by intelligently removing replicas. - - Prioritizes removal of unhealthy replicas first, then selects healthy replicas - with the lowest load. Migrates all workload (sessions and queued requests) - from removed replicas to remaining healthy replicas. - - Args: - num_replicas: Number of replicas to remove (default: 1) - - Note: - # Test context manager usage - async with service.session(): - await service.incr() - await service.incr() - result = await service.value() - assert result == 2 - - Sessions are reassigned on their next request rather than immediately - to avoid disrupting active workloads. 
- """ - logger.debug("Scaling down by %d replicas.", num_replicas) - - # Find replicas to remove (prefer unhealthy ones first, then least loaded) - replicas_to_remove = [] - - # First, try to remove unhealthy replicas - unhealthy_replicas = [r for r in self._replicas if not r.proc_mesh.healthy] - for replica in unhealthy_replicas[:num_replicas]: - replicas_to_remove.append(replica) - - # If we need more, remove healthy replicas with least load - remaining_to_remove = num_replicas - len(replicas_to_remove) - if remaining_to_remove > 0: - healthy_replicas = [ - r - for r in self._replicas - if r.proc_mesh.healthy and r not in replicas_to_remove - ] - # Sort by load (queue depth + active requests) - healthy_replicas.sort( - key=lambda r: r.request_queue.qsize() + r.active_requests - ) - - for replica in healthy_replicas[:remaining_to_remove]: - replicas_to_remove.append(replica) - - # Migrate sessions and requests from replicas being removed - for replica in replicas_to_remove: - await self._migrate_replica_workload(replica) - - # Stop the replica + async def _recover(replica): + """Recover a single replica.""" try: - await replica.proc_mesh.stop() + await replica.recover() + logger.debug("Successfully recovered replica %d", replica.idx) except Exception as e: - logger.warning("Error stopping replica %d: %s", replica.idx, e) + logger.error("Failed to recover replica %d: %s", replica.idx, e) + replica.mark_failed() - # Remove from replicas list - self._replicas.remove(replica) - - # Update replica indices - for i, replica in enumerate(self._replicas): - replica.idx = i + recovery_tasks = [ + asyncio.create_task(_recover(replica)) + for replica in self._replicas_to_recover + ] - logger.debug("Scale down complete. Remaining replicas: %d", len(self._replicas)) + await asyncio.gather(*recovery_tasks, return_exceptions=True) + self._replicas_to_recover.clear() async def _migrate_replica_workload(self, replica_to_remove: Replica): """Migrates all workload from a replica that's being removed.""" diff --git a/tests/test_service.py b/tests/test_service.py index 7283aeeee..b49818cfc 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -152,6 +152,80 @@ async def worker(increments: int): # Fault Tolerance Tests +@pytest.mark.timeout(20) +@pytest.mark.asyncio +async def test_recovery_state_transitions(): + """Test replica state transitions during failure and recovery.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=1, health_poll_rate=0.1) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # Initially replica should be healthy + replica = service._replicas[0] + assert replica.state.value == "HEALTHY" + assert replica.healthy is True + assert replica.failed is False + + # Create session and make a successful call + session = await service.start_session() + await service.incr(session) + result = await service.value(session) + assert result == 1 + + # Cause failure - this should transition to RECOVERING + error_result = await service.fail_me(session) + assert isinstance(error_result, RuntimeError) + + # Replica should now be in RECOVERING state + assert replica.state.value == "RECOVERING" + assert replica.healthy is False + assert replica.failed is True + + # Wait for health loop to detect and attempt recovery + # The health loop runs every 0.1s, so give it some time + max_wait_time = 5.0 # 5 seconds max wait + wait_interval = 0.1 + elapsed = 0.0 + + # Wait for replica to either recover (HEALTHY) or fail completely (UNHEALTHY) + while elapsed < 
max_wait_time: + await asyncio.sleep(wait_interval) + elapsed += wait_interval + + if replica.state.value in ["HEALTHY", "UNHEALTHY"]: + break + + # After recovery, replica should be healthy again + # (unless recovery failed, in which case it would be UNHEALTHY) + assert replica.state.value in ["HEALTHY", "UNHEALTHY"] + + if replica.state.value == "HEALTHY": + # If recovery succeeded, verify we can make calls again + assert replica.healthy is True + assert replica.failed is False + + # Test that we can make new calls after recovery + new_session = await service.start_session() + await service.incr(new_session) + result = await service.value(new_session) + assert ( + result is not None + ) # Should get a result (counter starts at 0 in new actor) + + elif replica.state.value == "UNHEALTHY": + # If recovery failed, verify failed state + assert replica.healthy is False + assert replica.failed is True + + # Verify that the state transition path was correct + # (We can't guarantee the exact end state due to potential flakiness in test environments, + # but we can verify the replica went through the expected transition) + logger.info(f"Final replica state: {replica.state.value}") + + finally: + await service.stop() + + @pytest.mark.timeout(15) @pytest.mark.asyncio async def test_replica_failure_and_recovery(): @@ -172,7 +246,7 @@ async def test_replica_failure_and_recovery(): # Replica should be marked as failed failed_replica = service._replicas[original_replica_idx] - assert not failed_replica.proc_mesh.healthy + assert not failed_replica.healthy # Session should be reassigned on next call await service.incr(session) @@ -183,7 +257,7 @@ async def test_replica_failure_and_recovery(): new_session = await service.start_session() await service.incr(new_session) assigned_replica = service._replicas[service._session_replica_map[new_session]] - assert assigned_replica.proc_mesh.healthy + assert assigned_replica.healthy finally: await service.stop() @@ -195,7 +269,7 @@ async def test_replica_failure_and_recovery(): @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_metrics_collection(): - """Test comprehensive metrics collection.""" + """Test metrics collection.""" cfg = ServiceConfig(procs_per_replica=1, num_replicas=2) service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) From 38bcad5a38ab475625f22a2c1602719ff526d63c Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Mon, 25 Aug 2025 13:35:36 -0700 Subject: [PATCH 07/15] Use positional actor_def to avoid actor_args being considered also actor_def (#67) --- src/forge/controller/replica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py index 7403bcb31..1fea467b3 100644 --- a/src/forge/controller/replica.py +++ b/src/forge/controller/replica.py @@ -168,7 +168,7 @@ async def initialize(self): # Spawn the actor await self.spawn_actor( - actor_def=self.actor_def, + self.actor_def, *self.actor_args, **self.actor_kwargs, ) From 1a7dd9aa22c7d10b15b49d3698c19c002a9a4538 Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Mon, 25 Aug 2025 14:15:29 -0700 Subject: [PATCH 08/15] Merge rewards files (#68) * Add reward interface, math reward, unit tests * refactor rewards: merge into one file * remove file accidentally had --- .../data/{rewards/math.py => rewards.py} | 16 ++++++++++++++++ src/forge/data/rewards/thinking.py | 19 ------------------- tests/unit_tests/rl/test_math_reward.py | 2 +- 
 tests/unit_tests/rl/test_thinking_reward.py |  2 +-
 4 files changed, 18 insertions(+), 21 deletions(-)
 rename src/forge/data/{rewards/math.py => rewards.py} (79%)
 delete mode 100644 src/forge/data/rewards/thinking.py

diff --git a/src/forge/data/rewards/math.py b/src/forge/data/rewards.py
similarity index 79%
rename from src/forge/data/rewards/math.py
rename to src/forge/data/rewards.py
index 06dd6cbc6..72973ae95 100644
--- a/src/forge/data/rewards/math.py
+++ b/src/forge/data/rewards.py
@@ -54,3 +54,19 @@ def __call__(self, prompt: str, response: str, target: str) -> float:
         if abs(expected_answer - model_answer) < self.tolerance:
             return 1.0  # Correct answer
         return 0.0  # Incorrect answer
+
+
+class ThinkingReward(Reward):
+    """Reward class for evaluating use of <think> tags in reasoning."""
+
+    def __init__(self, reward_value: float = 0.5):
+        self.reward_value = reward_value
+
+    def __call__(
+        self, prompt: str, response: str, target: Optional[str] = None
+    ) -> float:
+        """Check if response contains <think>...</think> tags."""
+        resp = response.lower()
+        if "<think>" in resp and "</think>" in resp:
+            return self.reward_value
+        return 0.0
diff --git a/src/forge/data/rewards/thinking.py b/src/forge/data/rewards/thinking.py
deleted file mode 100644
index 8c4eb6852..000000000
--- a/src/forge/data/rewards/thinking.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from typing import Optional
-
-from forge.interfaces import Reward
-
-
-class ThinkingReward(Reward):
-    """Reward class for evaluating use of <think> tags in reasoning."""
-
-    def __init__(self, reward_value: float = 0.5):
-        self.reward_value = reward_value
-
-    def __call__(
-        self, prompt: str, response: str, target: Optional[str] = None
-    ) -> float:
-        """Check if response contains <think>...</think> tags."""
-        resp = response.lower()
-        if "<think>" in resp and "</think>" in resp:
-            return self.reward_value
-        return 0.0
diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/rl/test_math_reward.py
index a109492dd..2f3521b4d 100644
--- a/tests/unit_tests/rl/test_math_reward.py
+++ b/tests/unit_tests/rl/test_math_reward.py
@@ -7,7 +7,7 @@
 import unittest
 from unittest import mock
 
-from forge.data.rewards.math import MathReward
+from forge.data.rewards import MathReward
 
 
 class TestMathReward(unittest.TestCase):
diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py
index d4af16969..592ceb896 100644
--- a/tests/unit_tests/rl/test_thinking_reward.py
+++ b/tests/unit_tests/rl/test_thinking_reward.py
@@ -6,7 +6,7 @@
 
 import unittest
 
-from forge.data.rewards.thinking import ThinkingReward
+from forge.data.rewards import ThinkingReward
 
 
 class TestThinkingReward(unittest.TestCase):

From ce1ed98434654754b1171f67a137eeadf53f7eb9 Mon Sep 17 00:00:00 2001
From: Allen Wang <9057208+allenwang28@users.noreply.github.com>
Date: Mon, 25 Aug 2025 17:00:11 -0500
Subject: [PATCH 09/15] [Service] Turns Service into an Actor and splits service into its own files (#69)

* initial commit for replica

* clean up

* phase out service for service v2

* remove v2

* remove v2 from spawn

* more minor cleanups

* remove comment

* remove comment

* initial commit of ServiceEndpoint

* tests work

* simplify and unify replica initialization

* stop the underlying service proc

* split out components into their own files

* address comments

* address comments

* add capacity semaphore

* rebasing changes

* fix test

* logger changes

* fix sess_id kwarg

* makes _call its own implementation

* docstring fix

* add comment on serviceinterface

---------

Co-authored-by: Allen Wang
---
 src/forge/controller/__init__.py  |   
4 + src/forge/controller/interface.py | 188 +++++++++++++++++++ src/forge/controller/metrics.py | 73 ++++++++ src/forge/controller/replica.py | 43 ++--- src/forge/controller/service.py | 277 ++++++++++++---------------- src/forge/controller/spawn.py | 22 ++- tests/test_service.py | 295 ++++++++++++++++++++++-------- 7 files changed, 629 insertions(+), 273 deletions(-) create mode 100644 src/forge/controller/interface.py create mode 100644 src/forge/controller/metrics.py diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index a800eb14c..66906fd6a 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from .actor import ForgeActor +from .interface import ServiceInterface, Session, SessionContext from .proc_mesh import get_proc_mesh, spawn_actors from .service import Service, ServiceConfig from .spawn import spawn_service @@ -11,6 +12,9 @@ __all__ = [ "Service", "ServiceConfig", + "ServiceInterface", + "Session", + "SessionContext", "spawn_service", "spawn_actors", "get_proc_mesh", diff --git a/src/forge/controller/interface.py b/src/forge/controller/interface.py new file mode 100644 index 000000000..3b6b4b687 --- /dev/null +++ b/src/forge/controller/interface.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Service interface and session management. + +This module provides the user-facing API for interacting with distributed services, +including session management, context propagation, and dynamic endpoint registration. +""" + +import contextvars +import logging +from dataclasses import dataclass +from typing import Generic, List, ParamSpec, TypeVar + +from monarch._src.actor.endpoint import EndpointProperty + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +P = ParamSpec("P") +R = TypeVar("R") + + +@dataclass +class Session: + """Simple session data holder.""" + + session_id: str + + +# Context variable for session state +_session_context = contextvars.ContextVar("session_context") + + +class SessionContext: + """ + Async context manager for stateful service sessions with automatic lifecycle management. + + Provides a convenient way to maintain stateful connections to replicas across multiple + requests. Sessions ensure that all requests within the context are routed to the same + replica, enabling stateful interactions while handling session lifecycle automatically. + + Example: + + >>> async with service.session() as session: + ... # All calls within this block use the same replica + ... result1 = await service.my_endpoint(arg1) + ... 
result2 = await service.another_endpoint(result1) + + """ + + def __init__(self, service: "ServiceInterface"): + self.service = service + self.session_id: str | None = None + self._token = None + + async def __aenter__(self): + """Start a session and set context variables.""" + self.session_id = await self.service.start_session() + # Set context for this async task + context_value = {"session_id": self.session_id} + self._token = _session_context.set(context_value) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Terminate the session and restore context.""" + if self._token: + _session_context.reset(self._token) + if self.session_id: + await self.service.terminate_session(self.session_id) + self.session_id = None + + +class ServiceEndpoint(Generic[P, R]): + """An endpoint object specific to services. + + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only choose and call are retained (dropping stream and call_one) + - Call returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + + def __init__(self, actor_mesh, endpoint_name: str): + self.actor_mesh = actor_mesh + self.endpoint_name = endpoint_name + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.actor_mesh.call.call_one( + sess_id, self.endpoint_name, *args, **kwargs + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.actor_mesh.call_all.call_one( + self.endpoint_name, *args, **kwargs + ) + return result + + +class ServiceInterface: + """ + A lightweight interface to a Service Actor running on a single-node mesh. + + This interface holds references to the proc_mesh and actor_mesh (both of size 1) + and exposes its user-defined actor endpoints as ServiceEndpoint objects that + route through the Service Actor's _call and _call_all endpoints. + + The ServiceInterface acts as the handle that is returned to end clients, + providing a simple interface that makes actual calls to the Service Actor. + + This is also needed to simplify serializing a handle to the service, in case + we want to pass this to other actors in the future. 
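+
+    Rough end-to-end sketch (Counter is the test actor used elsewhere in this
+    patch; endpoint names depend on the wrapped actor):
+
+        service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)
+        await service.incr.choose()           # routed to one healthy replica
+        results = await service.value.call()  # broadcast; returns a list
+        await service.stop()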
+ + """ + + def __init__(self, _proc_mesh, _service, actor_def): + self._proc_mesh = _proc_mesh + self._service = _service + self.actor_def = actor_def + + # Dynamically create ServiceEndpoint objects for user's actor endpoints + # Inspect the actor_def directly to find endpoints + for attr_name in dir(actor_def): + attr_value = getattr(actor_def, attr_name) + if isinstance(attr_value, EndpointProperty): + # Create a ServiceEndpoint that will route through the Service Actor + endpoint = ServiceEndpoint(self._service, attr_name) + setattr(self, attr_name, endpoint) + + # Session management methods - handled by ServiceInterface + async def start_session(self) -> str: + """Starts a new session for stateful request handling.""" + return await self._service.start_session.call_one() + + async def terminate_session(self, sess_id: str): + """Terminates an active session and cleans up associated resources.""" + return await self._service.terminate_session.call_one(sess_id) + + def session(self) -> "SessionContext": + """Returns a context manager for session-based calls.""" + return SessionContext(self) + + # Service control methods - forwarded to Service Actor + async def stop(self): + """Stops the service gracefully.""" + # First stop the service + await self._service.stop.call_one() + # Then stop its underlying proc + await self._proc_mesh.stop() + + # Metrics methods - forwarded to Service Actor + async def get_metrics(self): + """Get comprehensive service metrics for monitoring and analysis.""" + return await self._service.get_metrics.call_one() + + async def get_metrics_summary(self): + """Get a summary of key metrics for monitoring and debugging.""" + return await self._service.get_metrics_summary.call_one() + + # Testing method - forwarded to Service Actor + def _get_internal_state(self): + """ + Get comprehensive internal state for testing purposes. + + Returns: + dict: Complete internal state including sessions, replicas, and metrics + """ + return self._service._get_internal_state.call_one() + + def __getattr__(self, name: str): + """Forward all other attribute access to the underlying Service Actor.""" + # Forward everything else to the _service + if hasattr(self._service, name): + return getattr(self._service, name) + + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) diff --git a/src/forge/controller/metrics.py b/src/forge/controller/metrics.py new file mode 100644 index 000000000..dbd145862 --- /dev/null +++ b/src/forge/controller/metrics.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Service metrics collection and aggregation. + +This module provides comprehensive metrics tracking for distributed services, +including per-replica performance data, service-wide aggregations, and +health status information. +""" + +from dataclasses import dataclass, field +from typing import Dict, List + +from forge.controller.replica import ReplicaMetrics + + +# TODO - tie this into metrics logger when it exists. +@dataclass +class ServiceMetrics: + """ + Aggregated metrics collection for the entire service. + + Provides service-wide visibility into performance, health, and scaling metrics + by aggregating data from all replica instances. 
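+
+    A rough population sketch (illustrative; the service holding the replicas
+    fills these fields in):
+
+        metrics = ServiceMetrics()
+        metrics.total_replicas = len(replicas)
+        metrics.healthy_replicas = sum(1 for r in replicas if r.healthy)
+        for r in replicas:
+            metrics.replica_metrics[r.idx] = r.metrics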
+ + Attributes: + replica_metrics: Per-replica metrics indexed by replica ID + total_sessions: Number of active sessions across all replicas + healthy_replicas: Number of currently healthy replicas + total_replicas: Total number of replicas (healthy + unhealthy) + last_scale_event: Timestamp of the last scaling operation + """ + + # Replica metrics + replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict) + # Service-level metrics + total_sessions: int = 0 + healthy_replicas: int = 0 + total_replicas: int = 0 + # Time-based metrics + last_scale_event: float = 0.0 + + def get_total_request_rate(self, window_seconds: float = 60.0) -> float: + """Get total requests per second across all replicas.""" + return sum( + metrics.get_request_rate(window_seconds) + for metrics in self.replica_metrics.values() + ) + + def get_avg_queue_depth(self, replicas: List) -> float: + """Get average queue depth across all healthy replicas.""" + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: + return 0.0 + total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas) + return total_queue_depth / len(healthy_replicas) + + def get_avg_capacity_utilization(self, replicas: List) -> float: + """Get average capacity utilization across all healthy replicas.""" + healthy_replicas = [r for r in replicas if r.healthy] + if not healthy_replicas: + return 0.0 + total_utilization = sum(r.capacity_utilization for r in healthy_replicas) + return total_utilization / len(healthy_replicas) + + def get_sessions_per_replica(self) -> float: + """Get average sessions per replica.""" + if self.total_replicas == 0: + return 0.0 + return self.total_sessions / self.total_replicas diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py index 1fea467b3..9bb952679 100644 --- a/src/forge/controller/replica.py +++ b/src/forge/controller/replica.py @@ -199,22 +199,23 @@ async def _do_recovery(): if old_proc_mesh is not None: try: await old_proc_mesh.stop() - logger.debug("Old proc_mesh stopped for replica %d", self.idx) + logger.debug(f"Old proc_mesh stopped for replica {self.idx}") except Exception as e: logger.warning( - "Error stopping old proc_mesh for replica %d: %s", self.idx, e + f"Error stopping old proc_mesh for replica {self.idx}: {e}" ) try: - logger.debug("Creating new proc_mesh for replica %d", self.idx) + logger.debug(f"Creating new proc_mesh for replica {self.idx}") await self.initialize() - logger.debug("Recovery completed successfully for replica %d", self.idx) + logger.debug(f"Recovery completed successfully for replica {self.idx}") except Exception as e: - logger.error("Recovery failed for replica %d: %s", self.idx, e) + logger.error(f"Recovery failed for replica {self.idx}: {e}") self.state = ReplicaState.UNHEALTHY raise - logger.debug("Starting recovery for replica %d", self.idx) + logger.debug(f"Starting recovery for replica {self.idx}") + self.state = ReplicaState.RECOVERING self._recovery_task = asyncio.create_task(_do_recovery()) await self._recovery_task @@ -223,15 +224,15 @@ async def create_proc_mesh(self): # TODO - for policy replica, we would override this method to # include multiple proc_meshes if self.proc_mesh is not None: - logger.warning("Proc mesh already initialized for replica %d", self.idx) + logger.warning(f"Proc mesh already initialized for replica {self.idx}") return - logger.debug("Creating proc_mesh for replica %d", self.idx) + logger.debug(f"Creating proc_mesh for replica {self.idx}") try: self.proc_mesh = await 
get_proc_mesh(process_config=self.proc_config) - logger.debug("Proc mesh created successfully for replica %d", self.idx) + logger.debug(f"Proc mesh created successfully for replica {self.idx}") except Exception as e: - logger.error("Failed to create proc_mesh for replica %d: %s", self.idx, e) + logger.error(f"Failed to create proc_mesh for replica {self.idx}: {e}") self.state = ReplicaState.UNHEALTHY raise @@ -262,10 +263,10 @@ async def spawn_actor(self, actor_def, *actor_args, **actor_kwargs): if setup_method := getattr(self.actor, "setup", None): await setup_method.call() - logger.debug("Actor spawned successfully on replica %d", self.idx) + logger.debug(f"Actor spawned successfully on replica {self.idx}") except Exception as e: - logger.error("Failed to spawn actor on replica %d: %s", self.idx, e) + logger.error(f"Failed to spawn actor on replica {self.idx}: {e}") self.mark_failed() raise @@ -275,7 +276,7 @@ def start_processing(self): """Start the replica's processing loop if not already running.""" if self._run_task is None or self._run_task.done(): self._run_task = asyncio.create_task(self.run()) - logger.debug("Started processing loop for replica %d", self.idx) + logger.debug(f"Started processing loop for replica {self.idx}") async def enqueue_request(self, request: ServiceRequest): """Enqueues a request for processing by this replica.""" @@ -317,7 +318,7 @@ async def _process_single_request(self, request: ServiceRequest) -> bool: result = result._values[0] request.future.set_result(result) except ActorError as e: - logger.warning("Got failure on replica %d. Error:\n%s", self.idx, e) + logger.warning(f"Got failure on replica {self.idx}. Error:\n{e}") # The exception came from the actor. It itself is # returned to be propagated through the services # back to the caller. @@ -329,9 +330,7 @@ async def _process_single_request(self, request: ServiceRequest) -> bool: self.mark_failed() success = False except Exception as e: - logger.debug( - "Got unexpected error on replica %d. Error:\n%s", self.idx, e - ) + logger.debug(f"Got unexpected error on replica {self.idx}. 
Error:\n{e}") self.mark_failed() # The exception was not from the actor - in this case @@ -377,17 +376,13 @@ async def run(self): continue except Exception as e: - logger.error( - "Error in replica %d processing loop: %s", - self.idx, - e, - ) + logger.error(f"Error in replica {self.idx} processing loop: {e}") self.state = ReplicaState.UNHEALTHY break finally: self._running = False - logger.debug("Replica %d stopped processing", self.idx) + logger.debug(f"Replica {self.idx} stopped processing") # Replica state management @@ -418,7 +413,7 @@ def failed(self) -> bool: def mark_failed(self): """Mark the replica as failed, triggering recovery.""" - logger.debug("Marking replica %d as failed", self.idx) + logger.debug(f"Marking replica {self.idx} as failed") self.state = ReplicaState.RECOVERING async def stop(self): diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py index 6be098639..f59268b0e 100644 --- a/src/forge/controller/service.py +++ b/src/forge/controller/service.py @@ -33,140 +33,28 @@ """ import asyncio -import contextvars import logging import pprint import uuid -from dataclasses import dataclass, field from typing import Dict, List -from monarch._src.actor.endpoint import EndpointProperty +from monarch.actor import Actor, endpoint -from forge.controller.replica import Replica, ReplicaMetrics, ServiceRequest +from forge.controller.interface import _session_context, Session +from forge.controller.metrics import ServiceMetrics +from forge.controller.replica import Replica, ServiceRequest from forge.types import ServiceConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -# TODO - tie this into metrics logger when it exists. -@dataclass -class ServiceMetrics: - """ - Aggregated metrics collection for the entire service. - - Provides service-wide visibility into performance, health, and scaling metrics - by aggregating data from all replica instances. 
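A quick note on the logging change in the `replica.py` hunks above: moving from `%`-style arguments to f-strings reads better, but it gives up lazy interpolation, since the f-string is built before the level check. That is almost certainly negligible here; the sketch below (standalone, not part of the patch) just illustrates the difference.

```python
import logging

logger = logging.getLogger(__name__)
idx = 3  # stand-in for a replica index

# %-style arguments: the message is only interpolated if the record is
# actually emitted at this level.
logger.debug("Recovery completed successfully for replica %d", idx)

# f-string (the style this patch adopts): interpolation happens eagerly,
# before the level check; harmless here, but worth remembering if the
# interpolated values ever grow expensive __repr__ methods.
logger.debug(f"Recovery completed successfully for replica {idx}")
```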
- - Attributes: - replica_metrics: Per-replica metrics indexed by replica ID - total_sessions: Number of active sessions across all replicas - healthy_replicas: Number of currently healthy replicas - total_replicas: Total number of replicas (healthy + unhealthy) - last_scale_event: Timestamp of the last scaling operation - """ - - # Replica metrics - replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict) - # Service-level metrics - total_sessions: int = 0 - healthy_replicas: int = 0 - total_replicas: int = 0 - # Time-based metrics - last_scale_event: float = 0.0 - - def get_total_request_rate(self, window_seconds: float = 60.0) -> float: - """Get total requests per second across all replicas.""" - return sum( - metrics.get_request_rate(window_seconds) - for metrics in self.replica_metrics.values() - ) - - def get_avg_queue_depth(self, replicas: List) -> float: - """Get average queue depth across all healthy replicas.""" - healthy_replicas = [r for r in replicas if r.healthy] - if not healthy_replicas: - return 0.0 - total_queue_depth = sum(r.qsize() for r in healthy_replicas) - return total_queue_depth / len(healthy_replicas) - - def get_avg_capacity_utilization(self, replicas: List) -> float: - """Get average capacity utilization across all healthy replicas.""" - healthy_replicas = [r for r in replicas if r.healthy] - if not healthy_replicas: - return 0.0 - total_utilization = sum(r.capacity_utilization for r in healthy_replicas) - return total_utilization / len(healthy_replicas) - - def get_sessions_per_replica(self) -> float: - """Get average sessions per replica.""" - if self.total_replicas == 0: - return 0.0 - return self.total_sessions / self.total_replicas - - -# Context variable for session state -_session_context = contextvars.ContextVar("session_context") - - -@dataclass -class Session: - """Simple session data holder.""" - - session_id: str - - -class SessionContext: - """ - Async context manager for stateful service sessions with automatic lifecycle management. - - Provides a convenient way to maintain stateful connections to replicas across multiple - requests. Sessions ensure that all requests within the context are routed to the same - replica, enabling stateful interactions while handling session lifecycle automatically. - - Example: - - >>> async with service.session() as session: - ... # All calls within this block use the same replica - ... result1 = await service.my_endpoint(arg1) - ... result2 = await service.another_endpoint(result1) - - """ - - def __init__(self, service: "Service"): - self.service = service - self.session_id: str | None = None - self._token = None - - async def __aenter__(self): - """Start a session and set context variables.""" - self.session_id = await self.service.start_session() - # Set context for this async task - context_value = {"session_id": self.session_id} - self._token = _session_context.set(context_value) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Terminate the session and restore context.""" - if self._token: - _session_context.reset(self._token) - if self.session_id: - await self.service.terminate_session(self.session_id) - self.session_id = None - - -class Service: +class Service(Actor): """ Distributed Actor Service Controller - A sophisticated service orchestration system that manages multiple replicas of actor-based - services with automatic scaling, fault tolerance, and intelligent load balancing. 
- - The Service acts as a unified interface for distributed workloads, automatically handling: - - **Fault Tolerance**: Health monitoring, automatic replica recovery, request migration - - **Load Balancing**: Round-robin, least-loaded, and session-affinity routing - - **Session Management**: Stateful session handling with context propagation - - **Metrics Collection**: Comprehensive performance and health monitoring + A service orchestration system that manages multiple replicas of actor-based + services with fault tolerance and load balancing. Args: cfg: Service configuration including number of replicas, GPUs per replica, and health polling rate @@ -174,26 +62,6 @@ class Service: *actor_args: Positional arguments passed to actor constructor **actor_kwargs: Keyword arguments passed to actor constructor - Example: - Basic setup with autoscaling: - - >>> config = ServiceConfig( - ... gpus_per_replica=1, - ... num_replicas=3, - ... ) - >>> service = Service(config, MyActorClass, model_path="/path/to/model") - >>> await service.__initialize__() - - Session-based usage: - - >>> async with service.session(): - ... result1 = await service.my_endpoint(arg1, arg2) - ... result2 = await service.another_endpoint(arg3) - - Stateless usage: - - >>> result = await service.my_endpoint(arg1, arg2) # Uses round-robin - Attributes: _cfg: Service configuration _replicas: List of managed replica instances @@ -202,7 +70,9 @@ class Service: _endpoints: Dynamically registered actor endpoints """ - def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs): + def __init__( + self, cfg: ServiceConfig, actor_def, actor_args: tuple, actor_kwargs: dict + ): self._cfg = cfg self._replicas = [] self._actor_def = actor_def @@ -222,16 +92,7 @@ def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs): # Replica initialization queue self._replicas_to_recover = [] - # For all endpoints within the actor_def, create an interface from it - self._endpoints = [] - for func_name in dir(actor_def): - func = getattr(actor_def, func_name) - if isinstance(func, EndpointProperty): - logger.debug(f"Registering endpoint {func_name}") - self._endpoints.append(func_name) - # Dynamically add this endpoint method to the Service class - self._add_endpoint_method(func_name) - + @endpoint async def __initialize__(self): """Initializes the service and starts the health loop.""" logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.") @@ -262,14 +123,9 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - def _add_endpoint_method(self, endpoint_name: str): - """Dynamically adds an endpoint method to this Service instance.""" - - async def endpoint_method(sess_id: str | None = None, *args, **kwargs): - return await self._call(sess_id, endpoint_name, *args, **kwargs) - - # Set the method on this instance - setattr(self, endpoint_name, endpoint_method) + @endpoint + async def call(self, sess_id: str | None, function: str, *args, **kwargs): + return await self._call(sess_id, function, *args, **kwargs) async def _call(self, sess_id: str | None, function: str, *args, **kwargs): """ @@ -322,15 +178,65 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): # If the replica failed, try to retry once if not replica.healthy: logger.debug( - "Replica %d failed during request, retrying on healthy replica. Exception: %s", - replica.idx, - e, + f"Replica {replica.idx} failed during request, retrying on healthy replica. 
Exception: {e}" ) return await self._retry_request_on_healthy_replica( sess_id, function, *args, **kwargs ) raise + @endpoint + async def call_all(self, function: str, *args, **kwargs) -> List: + """ + Broadcasts a function call to all healthy replicas and returns results as a list. + + Args: + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint + + Returns: + List of results from all healthy replicas + + Raises: + RuntimeError: If no healthy replicas are available + """ + healthy_replicas = [r for r in self._replicas if r.healthy] + + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for broadcast call") + + # Create requests for all healthy replicas + requests = [] + for replica in healthy_replicas: + request = ServiceRequest( + session_id=None, # Broadcast calls don't use sessions + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + requests.append((replica, request)) + + # Enqueue all requests + for replica, request in requests: + await replica.enqueue_request(request) + + # Wait for all results + results = [] + for replica, request in requests: + try: + result = await request.future + results.append(result) + except Exception as e: + logger.warning( + f"Request to replica {replica.idx} failed during broadcast: {e}" + ) + # Add None for failed replicas to maintain indexing + results.append(None) + + return results + async def _retry_request_on_healthy_replica( self, sess_id: str | None, function: str, *args, **kwargs ): @@ -389,6 +295,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): ): self._session_replica_map[sess_id] = target_replica.idx + @endpoint async def start_session(self) -> str: """ Starts a new session for stateful request handling. @@ -414,10 +321,6 @@ async def start_session(self) -> str: return sess_id - def session(self) -> SessionContext: - """Returns a context manager for session-based calls.""" - return SessionContext(self) - def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) @@ -429,6 +332,7 @@ def _update_service_metrics(self): # Use the replica's own metrics directly self._metrics.replica_metrics[replica.idx] = replica.metrics + @endpoint def get_metrics(self) -> ServiceMetrics: """ Get comprehensive service metrics for monitoring and analysis. @@ -447,6 +351,7 @@ def get_metrics(self) -> ServiceMetrics: self._update_service_metrics() return self._metrics + @endpoint def get_metrics_summary(self) -> dict: """ Get a summary of key metrics for monitoring and debugging. @@ -504,6 +409,7 @@ def get_metrics_summary(self) -> dict: return summary + @endpoint async def terminate_session(self, sess_id: str): """ Terminates an active session and cleans up associated resources. @@ -604,6 +510,7 @@ async def _get_replica(self, sess_id: str | None) -> "Replica": logger.debug("Assigning session %s to replica %d", sess_id, replica.idx) return replica + @endpoint async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop @@ -670,5 +577,45 @@ async def _migrate_replica_workload(self, replica_to_remove: Replica): del self._session_replica_map[sess_id] logger.debug("Session %s will be reassigned on next request", sess_id) + @endpoint + def _get_internal_state(self) -> dict: + """ + Gets comprehensive internal state for testing purposes. 
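The `call_all` endpoint added earlier in this hunk is a straightforward scatter/gather: one `ServiceRequest` with its own future per healthy replica, everything enqueued up front, then the futures awaited in order, with `None` standing in for replicas that fail mid-broadcast so the result list stays index-aligned. A stripped-down sketch of that pattern outside the Service class (all names here are illustrative stand-ins, not codebase types):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class _StubReplica:
    """Illustrative stand-in for Replica: pulls requests off its own queue."""

    idx: int
    queue: asyncio.Queue = field(default_factory=asyncio.Queue)

    async def worker(self) -> None:
        while True:
            value, fut = await self.queue.get()
            fut.set_result(value * 2)  # pretend to run the actor endpoint


async def broadcast(replicas, value):
    # One future per replica, enqueued up front...
    pending = []
    for replica in replicas:
        fut = asyncio.get_running_loop().create_future()
        await replica.queue.put((value, fut))
        pending.append((replica, fut))

    # ...then gathered in order, with None substituted for failures so the
    # result list stays aligned with the replica list.
    results = []
    for replica, fut in pending:
        try:
            results.append(await fut)
        except Exception:
            results.append(None)
    return results


async def main() -> None:
    replicas = [_StubReplica(i) for i in range(3)]
    workers = [asyncio.create_task(r.worker()) for r in replicas]
    print(await broadcast(replicas, 21))  # -> [42, 42, 42]
    for w in workers:
        w.cancel()


asyncio.run(main())
```

Returning `None` for a failed replica rather than raising mirrors the design choice above: a partially failed broadcast still yields a full-length result list.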
+ + This is intended for testing/debugging only, it should not + be relied upon in actual production code. + """ + # Ensure metrics are up to date + self._update_service_metrics() + + return { + # Session management state + "session_replica_map": dict(self._session_replica_map), # Copy for safety + "active_sessions": [s.session_id for s in self._active_sessions], + "id_session_map": dict(self._id_session_map), # Copy for safety + # Replica state + "replicas": [ + { + "idx": replica.idx, + "state": replica.state.value, + "healthy": replica.healthy, + "failed": replica.failed, + "active_requests": replica.active_requests, + "queue_size": replica.request_queue.qsize(), + "capacity_utilization": replica.capacity_utilization, + } + for replica in self._replicas + ], + # Load balancing state + "next_replica_idx": self._next_replica_idx, + # Service-level state + "total_replicas": len(self._replicas), + "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), + "shutdown_requested": self._shutdown_requested, + # Metrics summary + "total_sessions": len(self._active_sessions), + "replica_count": len(self._replicas), + } + def __repr__(self): return f"Service(actor={self._actor_def.__name__})" diff --git a/src/forge/controller/spawn.py b/src/forge/controller/spawn.py index fe0512277..5044fed7d 100644 --- a/src/forge/controller/spawn.py +++ b/src/forge/controller/spawn.py @@ -8,9 +8,10 @@ import logging from typing import Type -from monarch.actor import Actor +from monarch.actor import Actor, proc_mesh from forge.controller import Service, ServiceConfig +from forge.controller.interface import ServiceInterface logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -18,7 +19,7 @@ async def spawn_service( service_cfg: ServiceConfig, actor_def: Type[Actor], *actor_args, **actor_kwargs -) -> Service: +) -> ServiceInterface: """Spawns a service based on the actor class. 
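`spawn_service` now returns a `ServiceInterface` rather than the `Service` itself. That class lives in `forge/controller/interface.py` and is not part of this diff, but the updated tests below lean on sugar of the form `service.incr.choose(...)` and `service.value.call(...)`, which presumably forwards to the Service actor's `call` and `call_all` endpoints. A rough, hypothetical sketch of that shape (not the actual implementation):

```python
class _EndpointProxy:
    """Hypothetical proxy for one named endpoint on the Service actor."""

    def __init__(self, service_actor, name: str):
        self._service_actor = service_actor
        self._name = name

    async def choose(self, *args, sess_id: str | None = None, **kwargs):
        # Route to a single replica via Service.call(sess_id, function, ...).
        return await self._service_actor.call.call_one(
            sess_id, self._name, *args, **kwargs
        )

    async def call(self, *args, **kwargs):
        # Broadcast to every healthy replica via Service.call_all(function, ...).
        return await self._service_actor.call_all.call_one(
            self._name, *args, **kwargs
        )


class ServiceInterfaceSketch:
    """Hypothetical stand-in for ServiceInterface: wraps proc_mesh, actor, actor_def."""

    def __init__(self, proc_mesh, service_actor, actor_def):
        self._proc_mesh = proc_mesh
        self._service_actor = service_actor
        self._actor_def = actor_def

    def __getattr__(self, name: str) -> _EndpointProxy:
        # Only reached when normal attribute lookup fails, so the fields
        # set in __init__ are not intercepted.
        return _EndpointProxy(self._service_actor, name)
```

The real interface also has to expose service-level helpers such as `start_session`, `terminate_session`, `get_metrics_summary`, and `stop` directly (the tests await them without `.choose()`), which this sketch glosses over.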
Args: @@ -28,10 +29,15 @@ async def spawn_service( **actor_kwargs: Keyword arguments to pass to actor constructor Returns: - The appropriate service type based on the actor class + A ServiceInterface that provides access to the Service Actor """ - # Default to base Service - logger.info("Spawning base Service for %s", actor_def.__name__) - service = Service(service_cfg, actor_def, *actor_args, **actor_kwargs) - await service.__initialize__() - return service + # Create a single-node proc_mesh and actor_mesh for the Service Actor + logger.info("Spawning Service Actor for %s", actor_def.__name__) + m = await proc_mesh(gpus=1) + service_actor = await m.spawn( + "service", Service, service_cfg, actor_def, actor_args, actor_kwargs + ) + await service_actor.__initialize__.call_one() + + # Return the ServiceInterface that wraps the proc_mesh, actor_mesh, and actor_def + return ServiceInterface(m, service_actor, actor_def) diff --git a/tests/test_service.py b/tests/test_service.py index b49818cfc..4b7f4a738 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -47,6 +47,13 @@ async def slow_incr(self): await asyncio.sleep(1.0) self.v += 1 + @endpoint + async def add_to_value(self, amount: int, multiplier: int = 1) -> int: + """Add an amount (optionally multiplied) to the current value.""" + logger.info(f"adding {amount} with {multiplier}") + self.v += amount * multiplier + return self.v + # Core Functionality Tests @@ -66,16 +73,18 @@ async def test_basic_service_operations(): assert isinstance(session1, str) # Test endpoint calls - await service.incr(session1) - result = await service.value(session1) + await service.incr.choose(sess_id=session1) + result = await service.value.choose(sess_id=session1) assert result == 1 # Test session mapping - assert session1 in service._session_replica_map + state = await service._get_internal_state() + assert session1 in state["session_replica_map"] # Test session termination await service.terminate_session(session1) - assert session1 not in service._session_replica_map + state = await service._get_internal_state() + assert session1 not in state["session_replica_map"] finally: await service.stop() @@ -84,29 +93,34 @@ async def test_basic_service_operations(): @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_sessionless_calls(): - """Test sessionless calls with round-robin load balancing.""" + """Test sessionless calls with round robin load balancing.""" cfg = ServiceConfig(procs_per_replica=1, num_replicas=2) service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) try: # Test sessionless calls - await service.incr() - await service.incr() - result = await service.value() + await service.incr.choose() + await service.incr.choose() + result = await service.value.choose() assert result is not None # No sessions should be created - assert len(service._active_sessions) == 0 - assert len(service._session_replica_map) == 0 + state = await service._get_internal_state() + assert len(state["active_sessions"]) == 0 + assert len(state["session_replica_map"]) == 0 # Verify load distribution - metrics = service.get_metrics_summary() + metrics = await service.get_metrics_summary() total_requests = sum( replica_metrics["total_requests"] for replica_metrics in metrics["replicas"].values() ) assert total_requests == 3 + # Users should be able to call endpoint with just args + result = await service.add_to_value.choose(5, multiplier=2) + assert result == 11 # 1 + 10 + finally: await service.stop() @@ -121,18 +135,18 @@ async def 
test_session_context_manager(): try: # Test context manager usage async with service.session(): - await service.incr() - await service.incr() - result = await service.value() + await service.incr.choose() + await service.incr.choose() + result = await service.value.choose() assert result == 2 # Test sequential context managers to avoid interference async def worker(increments: int): async with service.session(): - initial = await service.value() + initial = await service.value.choose() for _ in range(increments): - await service.incr() - final = await service.value() + await service.incr.choose() + final = await service.value.choose() return final - initial # Run sessions sequentially to avoid concurrent modification @@ -142,8 +156,9 @@ async def worker(increments: int): assert sorted(results) == [2, 3] # Test that context manager properly manages session lifecycle - assert len(service._active_sessions) == 0 - assert len(service._session_replica_map) == 0 + state = await service._get_internal_state() + assert len(state["active_sessions"]) == 0 + assert len(state["session_replica_map"]) == 0 finally: await service.stop() @@ -161,25 +176,28 @@ async def test_recovery_state_transitions(): try: # Initially replica should be healthy - replica = service._replicas[0] - assert replica.state.value == "HEALTHY" - assert replica.healthy is True - assert replica.failed is False + state = await service._get_internal_state() + replica_state = state["replicas"][0] + assert replica_state["state"] == "HEALTHY" + assert replica_state["healthy"] is True + assert replica_state["failed"] is False # Create session and make a successful call session = await service.start_session() - await service.incr(session) - result = await service.value(session) + await service.incr.choose(sess_id=session) + result = await service.value.choose(sess_id=session) assert result == 1 # Cause failure - this should transition to RECOVERING - error_result = await service.fail_me(session) + error_result = await service.fail_me.choose(sess_id=session) assert isinstance(error_result, RuntimeError) # Replica should now be in RECOVERING state - assert replica.state.value == "RECOVERING" - assert replica.healthy is False - assert replica.failed is True + state = await service._get_internal_state() + replica_state = state["replicas"][0] + assert replica_state["state"] == "RECOVERING" + assert replica_state["healthy"] is False + assert replica_state["failed"] is True # Wait for health loop to detect and attempt recovery # The health loop runs every 0.1s, so give it some time @@ -192,35 +210,39 @@ async def test_recovery_state_transitions(): await asyncio.sleep(wait_interval) elapsed += wait_interval - if replica.state.value in ["HEALTHY", "UNHEALTHY"]: + state = await service._get_internal_state() + replica_state = state["replicas"][0] + if replica_state["state"] in ["HEALTHY", "UNHEALTHY"]: break # After recovery, replica should be healthy again # (unless recovery failed, in which case it would be UNHEALTHY) - assert replica.state.value in ["HEALTHY", "UNHEALTHY"] + state = await service._get_internal_state() + replica_state = state["replicas"][0] + assert replica_state["state"] in ["HEALTHY", "UNHEALTHY"] - if replica.state.value == "HEALTHY": + if replica_state["state"] == "HEALTHY": # If recovery succeeded, verify we can make calls again - assert replica.healthy is True - assert replica.failed is False + assert replica_state["healthy"] is True + assert replica_state["failed"] is False # Test that we can make new calls after recovery 
new_session = await service.start_session() - await service.incr(new_session) - result = await service.value(new_session) + await service.incr.choose(sess_id=new_session) + result = await service.value.choose(sess_id=new_session) assert ( result is not None ) # Should get a result (counter starts at 0 in new actor) - elif replica.state.value == "UNHEALTHY": + elif replica_state["state"] == "UNHEALTHY": # If recovery failed, verify failed state - assert replica.healthy is False - assert replica.failed is True + assert replica_state["healthy"] is False + assert replica_state["failed"] is True # Verify that the state transition path was correct # (We can't guarantee the exact end state due to potential flakiness in test environments, # but we can verify the replica went through the expected transition) - logger.info(f"Final replica state: {replica.state.value}") + logger.info(f"Final replica state: {replica_state['state']}") finally: await service.stop() @@ -236,28 +258,32 @@ async def test_replica_failure_and_recovery(): try: # Create session and cause failure session = await service.start_session() - await service.incr(session) + await service.incr.choose(sess_id=session) - original_replica_idx = service._session_replica_map[session] + state = await service._get_internal_state() + original_replica_idx = state["session_replica_map"][session] # Cause failure - error_result = await service.fail_me(session) + error_result = await service.fail_me.choose(sess_id=session) assert isinstance(error_result, RuntimeError) # Replica should be marked as failed - failed_replica = service._replicas[original_replica_idx] - assert not failed_replica.healthy + state = await service._get_internal_state() + failed_replica = state["replicas"][original_replica_idx] + assert not failed_replica["healthy"] # Session should be reassigned on next call - await service.incr(session) - new_replica_idx = service._session_replica_map[session] + await service.incr.choose(sess_id=session) + state = await service._get_internal_state() + new_replica_idx = state["session_replica_map"][session] assert new_replica_idx != original_replica_idx # New sessions should avoid failed replica new_session = await service.start_session() - await service.incr(new_session) - assigned_replica = service._replicas[service._session_replica_map[new_session]] - assert assigned_replica.healthy + await service.incr.choose(sess_id=new_session) + state = await service._get_internal_state() + assigned_replica = state["replicas"][state["session_replica_map"][new_session]] + assert assigned_replica["healthy"] finally: await service.stop() @@ -278,17 +304,17 @@ async def test_metrics_collection(): session1 = await service.start_session() session2 = await service.start_session() - await service.incr(session1) - await service.incr(session1) - await service.incr(session2) + await service.incr.choose(sess_id=session1) + await service.incr.choose(sess_id=session1) + await service.incr.choose(sess_id=session2) # Test failure metrics - error_result = await service.fail_me(session1) + error_result = await service.fail_me.choose(sess_id=session1) assert isinstance(error_result, RuntimeError) # Get metrics - metrics = service.get_metrics() - summary = service.get_metrics_summary() + metrics = await service.get_metrics() + summary = await service.get_metrics_summary() # Test service-level metrics assert metrics.total_sessions == 2 @@ -330,18 +356,20 @@ async def test_session_stickiness(): session = await service.start_session() # Make multiple calls - await 
service.incr(session) - await service.incr(session) - await service.incr(session) + await service.incr.choose(sess_id=session) + await service.incr.choose(sess_id=session) + await service.incr.choose(sess_id=session) # Should always route to same replica - replica_idx = service._session_replica_map[session] + state = await service._get_internal_state() + replica_idx = state["session_replica_map"][session] - await service.incr(session) - assert service._session_replica_map[session] == replica_idx + await service.incr.choose(sess_id=session) + state = await service._get_internal_state() + assert state["session_replica_map"][session] == replica_idx # Verify counter was incremented correctly - result = await service.value(session) + result = await service.value.choose(sess_id=session) assert result == 4 finally: @@ -358,20 +386,25 @@ async def test_load_balancing_multiple_sessions(): try: # Create sessions with some load to trigger distribution session1 = await service.start_session() - await service.incr(session1) # Load replica 0 + await service.incr.choose(sess_id=session1) # Load replica 0 session2 = await service.start_session() - await service.incr(session2) # Should go to replica 1 (least loaded) + await service.incr.choose( + sess_id=session2 + ) # Should go to replica 1 (least loaded) session3 = await service.start_session() - await service.incr(session3) # Should go to replica 0 or 1 based on load + await service.incr.choose( + sess_id=session3 + ) # Should go to replica 0 or 1 based on load session4 = await service.start_session() - await service.incr(session4) # Should balance the load + await service.incr.choose(sess_id=session4) # Should balance the load # Check that sessions are distributed (may not be perfectly even due to least-loaded logic) + state = await service._get_internal_state() replica_assignments = [ - service._session_replica_map[s] + state["session_replica_map"][s] for s in [session1, session2, session3, session4] ] unique_replicas = set(replica_assignments) @@ -381,7 +414,7 @@ async def test_load_balancing_multiple_sessions(): assert len(unique_replicas) >= 1 # At least one replica used # Verify that load balancing is working by checking request distribution - metrics = service.get_metrics_summary() + metrics = await service.get_metrics_summary() total_requests = sum( replica_metrics["total_requests"] for replica_metrics in metrics["replicas"].values() @@ -407,20 +440,21 @@ async def test_concurrent_operations(): # Concurrent operations tasks = [ - service.incr(session), # Session call - service.incr(session), # Session call - service.incr(), # Sessionless call - service.incr(), # Sessionless call + service.incr.choose(sess_id=session), # Session call + service.incr.choose(sess_id=session), # Session call + service.incr.choose(), # Sessionless call + service.incr.choose(), # Sessionless call ] await asyncio.gather(*tasks) # Verify session tracking - assert len(service._active_sessions) == 1 - assert session in service._session_replica_map + state = await service._get_internal_state() + assert len(state["active_sessions"]) == 1 + assert session in state["session_replica_map"] # Verify total requests - metrics = service.get_metrics_summary() + metrics = await service.get_metrics_summary() total_requests = sum( replica_metrics["total_requests"] for replica_metrics in metrics["replicas"].values() @@ -429,3 +463,112 @@ async def test_concurrent_operations(): finally: await service.stop() + + +# `call` endpoint tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio 
+async def test_broadcast_call_basic(): + """Test basic broadcast call functionality.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=10) + + try: + # Test broadcast call to all replicas + results = await service.incr.call() + + # Should get results from all healthy replicas + assert isinstance(results, list) + assert len(results) == 3 # All 3 replicas should respond + + # All results should be None (incr doesn't return anything) + assert all(result is None for result in results) + + # Test getting values from all replicas + values = await service.value.call() + assert isinstance(values, list) + assert len(values) == 3 + + # All replicas should have incremented from 10 to 11 + assert all(value == 11 for value in values) + + finally: + await service.stop() + + +@pytest.mark.timeout(15) +@pytest.mark.asyncio +async def test_broadcast_call_with_failed_replica(): + """Test broadcast call behavior when some replicas fail.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # First, cause one replica to fail by calling fail_me on a specific session + session = await service.start_session() + try: + await service.fail_me.choose(sess_id=session) + except RuntimeError: + pass # Expected failure + + # Wait briefly for replica to be marked as failed + await asyncio.sleep(0.1) + + # Now test broadcast call - should only hit healthy replicas + results = await service.incr.call() + + # Should get results from healthy replicas only + assert isinstance(results, list) + # Results length should match number of healthy replicas (2 out of 3) + state = await service._get_internal_state() + healthy_count = sum(1 for r in state["replicas"] if r["healthy"]) + assert len(results) == healthy_count + + # Get values from all healthy replicas + values = await service.value.call() + assert len(values) == healthy_count + + # All healthy replicas should have incremented to 1 + assert all(value == 1 for value in values) + + finally: + await service.stop() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_vs_choose(): + """Test that broadcast call hits all replicas while choose hits only one.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # Use broadcast call to increment all replicas + await service.incr.call() + + # Get values from all replicas + values_after_broadcast = await service.value.call() + assert len(values_after_broadcast) == 3 + assert all(value == 1 for value in values_after_broadcast) + + # Use choose to increment only one replica + await service.incr.choose() + + # Get values again - one replica should be at 2, others at 1 + values_after_choose = await service.value.call() + assert len(values_after_choose) == 3 + assert sorted(values_after_choose) == [1, 1, 2] # One replica incremented twice + + # Verify metrics show the correct number of requests + metrics = await service.get_metrics_summary() + total_requests = sum( + replica_metrics["total_requests"] + for replica_metrics in metrics["replicas"].values() + ) + # incr.call() (3 requests) + value.call() (3 requests) + incr.choose() (1 request) + value.call() (3 requests) = 10 total + assert total_requests == 10 + + finally: + await service.stop() From 922b492f024aedb27b5f01649ee1b234c43f7646 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Wed, 20 Aug 
2025 11:38:47 -0700 Subject: [PATCH 10/15] Pushing Policy Worker:rough --- src/forge/actors/policy.py | 96 +++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 6866a6949..c096a1657 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -14,6 +14,7 @@ import torch +from forge.interfaces import Policy as PolicyInterface from monarch.actor import Actor, current_rank, endpoint, proc_mesh from torchstore import MultiProcessStore @@ -45,18 +46,21 @@ @dataclass -class PolicyRouter(Actor): +class Policy(PolicyInterface): # TODO: Add dp support - policy: Actor + policy_worker: Actor = None sampling_params: SamplingParams = None lora_request: LoRARequest = None tokenization_kwargs: dict = None @endpoint - async def setup(self): + async def setup(self, config, guided_decoding=False, num_samples=1): + # Set up workers + await self.setupWorker(config, guided_decoding, num_samples) + self.request_id = 0 self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {} - self.vllm_args = await self.policy.get_vllm_args.choose() + self.vllm_args = await self.policy_worker.get_vllm_args.choose() # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` @@ -72,7 +76,7 @@ async def setup(self): # Setup schduuler # TODO: Add support for `log_stats` - kv_cache_configs = await self.policy.setup_kv_cache.call() + kv_cache_configs = await self.policy_worker.setup_kv_cache.call() kv_cache_config = kv_cache_configs._values[0] self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks self.vllm_args.cache_config.num_cpu_blocks = 0 @@ -86,6 +90,34 @@ async def setup(self): log_stats=None, ) + async def setupWorker(self, config, guided_decoding, num_samples): + self.worker_mesh = await proc_mesh( + gpus=config["resources"], + env={ + "MASTER_ADDR": str(get_loopback_ip()), + "MASTER_PORT": str(get_open_port()), + }, + ) + self.policy_worker = await self.worker_mesh.spawn( + "policy_worker", PolicyWorker, **config + ) + + # TODO: Make this customizable from the config + vllm_args = await self.policy_worker.get_vllm_args.choose() + overrides = { + "n": num_samples, + "guided_decoding": ( + GuidedDecodingParams(choice=["Positive", "Negative"]) + if guided_decoding + else None + ), + } + self.sampling_params = get_default_sampling_params( + vllm_args, overrides=overrides + ) + + await self.policy_worker.setup.call() + @endpoint async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]: self.request_id += 1 % sys.maxsize @@ -169,7 +201,9 @@ async def run(self): self.running = True while self.running: scheduler_output = self.scheduler.schedule() - worker_outputs = await self.policy.execute_model.call(scheduler_output) + worker_outputs = await self.policy_worker.execute_model.call( + scheduler_output + ) worker_output = worker_outputs._values[output_rank] outputs = self.scheduler.update_from_output(scheduler_output, worker_output) outputs = outputs.get(0) or EngineCoreOutputs() @@ -185,13 +219,18 @@ async def run(self): _, fut = self.requests.pop(request_output.request_id) fut.set_result(request_output.outputs) + @endpoint + async def update_weights(self): + """Update the policy weights.""" + pass + @endpoint async def shutdown(self): self.running = False @dataclass -class Policy(Actor): +class PolicyWorker(Actor): model: str tensor_parallel_size: int = 1 pipeline_parallel_size: int = 1 @@ 
-389,46 +428,17 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams: async def _test(config, guided_decoding=False, num_samples=1): # TODO: Create proper test - router_mesh = await proc_mesh(gpus=1) - policy_mesh = await proc_mesh( - gpus=config["resources"], - env={ - "MASTER_ADDR": str(get_loopback_ip()), - "MASTER_PORT": str(get_open_port()), - }, - ) - - policy_actor = await policy_mesh.spawn("policy", Policy, **config) - - # TODO: Make this customizable from the config - overrides = { - "n": num_samples, - "guided_decoding": ( - GuidedDecodingParams(choice=["Positive", "Negative"]) - if guided_decoding - else None - ), - } - - vllm_args = await policy_actor.get_vllm_args.choose() - sampling_params = get_default_sampling_params(vllm_args, overrides=overrides) - - router = await router_mesh.spawn( - "policy_router", - PolicyRouter, - policy=policy_actor, - sampling_params=sampling_params, - ) + policy_mesh = await proc_mesh(gpus=1) + policy = await policy_mesh.spawn("policy", Policy) - await policy_actor.setup.call() - await router.setup.call() print("Model setup") + await policy.setup.call(config, guided_decoding, num_samples) - router.run.call() print("Model running") + policy.run.call() prompt = "What is 3+5?" if guided_decoding else "Tell me a joke" - responses: List[CompletionOutput] = await router.generate.call_one(prompt) + responses: List[CompletionOutput] = await policy.generate.call_one(prompt) for batch, response in enumerate(responses): print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") @@ -436,7 +446,7 @@ async def _test(config, guided_decoding=False, num_samples=1): print(f"User: {prompt}\nAssistant: {response.text}") print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - await router.shutdown.call() + await policy.shutdown.call() if __name__ == "__main__": @@ -450,4 +460,4 @@ async def _test(config, guided_decoding=False, num_samples=1): # asyncio.run(_test(config)) # asyncio.run(_test(config, guided_decoding=True)) # asyncio.run(_test(config, num_samples=2)) - # asyncio.run(_test(config, guided_decoding=True, num_samples=3)) + asyncio.run(_test(config, guided_decoding=True, num_samples=3)) From ca9ba0512aa53e3aaad68ee17e7c6dd8dacad154 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 21 Aug 2025 10:10:28 -0700 Subject: [PATCH 11/15] Partial paths for sapwn_actor and spawn service --- src/forge/actors/policy.py | 131 ++++++++++++++++++++++--------------- src/forge/interfaces.py | 14 ++++ 2 files changed, 94 insertions(+), 51 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index c096a1657..74bb05b72 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -13,9 +13,16 @@ from typing import Dict, List import torch +from forge.controller import spawn_actors +from forge.controller.service import ServiceConfig +from forge.controller.spawn import spawn_service +from forge.data.sharding import VLLMSharding from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig + from monarch.actor import Actor, current_rank, endpoint, proc_mesh +from omegaconf import DictConfig, OmegaConf from torchstore import MultiProcessStore from torchstore._state_dict_utils import DELIM @@ -40,27 +47,43 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.data.sharding import VLLMSharding - logger = logging.getLogger(__name__) @dataclass class Policy(PolicyInterface): # TODO: Add dp 
support + config: DictConfig + # Gets set up by setup worker policy_worker: Actor = None + sampling_params: SamplingParams = None lora_request: LoRARequest = None tokenization_kwargs: dict = None @endpoint - async def setup(self, config, guided_decoding=False, num_samples=1): - # Set up workers - await self.setupWorker(config, guided_decoding, num_samples) + async def setup(self): + # Set up policy_worker + await self.spawn_workers() self.request_id = 0 self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {} self.vllm_args = await self.policy_worker.get_vllm_args.choose() + + # Setup sampling params + sampling_overrides = self.config.sampling_params + overrides = { + "n": sampling_overrides.num_samples, + "guided_decoding": ( + GuidedDecodingParams(choice=["Positive", "Negative"]) + if sampling_overrides.guided_decoding + else None + ), + } + self.sampling_params = get_default_sampling_params( + self.vllm_args, overrides=overrides + ) + # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` @@ -74,7 +97,7 @@ async def setup(self, config, guided_decoding=False, num_samples=1): ) self.output_processor = OutputProcessor(tokenizer, log_stats=None) - # Setup schduuler + # Setup scheduler # TODO: Add support for `log_stats` kv_cache_configs = await self.policy_worker.setup_kv_cache.call() kv_cache_config = kv_cache_configs._values[0] @@ -90,32 +113,20 @@ async def setup(self, config, guided_decoding=False, num_samples=1): log_stats=None, ) - async def setupWorker(self, config, guided_decoding, num_samples): + def should_spawn_workers(self) -> bool: + return True + + async def spawn_workers(self): self.worker_mesh = await proc_mesh( - gpus=config["resources"], + gpus=self.config.num_workers, env={ "MASTER_ADDR": str(get_loopback_ip()), "MASTER_PORT": str(get_open_port()), }, ) self.policy_worker = await self.worker_mesh.spawn( - "policy_worker", PolicyWorker, **config - ) - - # TODO: Make this customizable from the config - vllm_args = await self.policy_worker.get_vllm_args.choose() - overrides = { - "n": num_samples, - "guided_decoding": ( - GuidedDecodingParams(choice=["Positive", "Negative"]) - if guided_decoding - else None - ), - } - self.sampling_params = get_default_sampling_params( - vllm_args, overrides=overrides + "policy_worker", PolicyWorker, **self.config.worker_params ) - await self.policy_worker.setup.call() @endpoint @@ -125,8 +136,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu # Wraps prompt into a dict prompt: Dict[str, str] = convert_input(prompt) - if self.sampling_params is None: - self.sampling_params = get_default_sampling_params(self.vllm_args) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -236,7 +245,6 @@ class PolicyWorker(Actor): pipeline_parallel_size: int = 1 enforce_eager: bool = False vllm_args: EngineArgs = None - resources: int = 1 state_dict_key: str = "model_state_dict" def __post_init__(self): @@ -279,7 +287,6 @@ def __post_init__(self): setattr(self.vllm_args, key, value) # Build Config self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS) - assert self.vllm_args.parallel_config.world_size == self.resources @endpoint async def setup(self, store: MultiProcessStore = None): @@ -426,19 +433,32 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams: return params -async def _test(config, guided_decoding=False, num_samples=1): - # TODO: Create proper test - policy_mesh 
= await proc_mesh(gpus=1) - policy = await policy_mesh.spawn("policy", Policy) +# TODO: Create proper test +async def _test(config: DictConfig): + prompt = ( + "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke" + ) - print("Model setup") - await policy.setup.call(config, guided_decoding, num_samples) - - print("Model running") - policy.run.call() + with_service = False + if with_service: + service_config = ServiceConfig( + procs_per_replica=1, min_replicas=1, max_replicas=2, default_replicas=1 + ) + print("spawning service") + service = await spawn_service(service_config, Policy, config=config) + service.run() + responses: List[CompletionOutput] = await service.generate(prompt) - prompt = "What is 3+5?" if guided_decoding else "Tell me a joke" - responses: List[CompletionOutput] = await policy.generate.call_one(prompt) + else: + process_config = ProcessConfig() + policy = await spawn_actors( + name="policy", actor_cls=Policy, cfg=config, processes=process_config + ) + print("Model setup") + await policy.setup.call() + print("Model running") + policy.run.call() + responses: List[CompletionOutput] = await policy.generate.call_one(prompt) for batch, response in enumerate(responses): print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") @@ -446,18 +466,27 @@ async def _test(config, guided_decoding=False, num_samples=1): print(f"User: {prompt}\nAssistant: {response.text}") print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") - await policy.shutdown.call() + if with_service: + await service.stop() + else: + await policy.shutdown.call() if __name__ == "__main__": - config = { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 1, - "enforce_eager": True, - "resources": 2, - } - # asyncio.run(_test(config)) - # asyncio.run(_test(config, guided_decoding=True)) - # asyncio.run(_test(config, num_samples=2)) - asyncio.run(_test(config, guided_decoding=True, num_samples=3)) + config = OmegaConf.create( + { + "num_workers": 2, + "worker_params": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 1, + "enforce_eager": True, + "vllm_args": None, + }, + "sampling_params": { + "guided_decoding": True, + "num_samples": 2, + }, + } + ) + asyncio.run(_test(config)) diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index b485fc791..fd8da69ad 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -87,6 +87,20 @@ async def update_weights(self): """Update the policy weights.""" pass + @abstractmethod + def should_spawn_workers(self) -> bool: + """Whether the policy needs to separately spawrn child workers.""" + pass + + @abstractmethod + def spawn_workers(self): + """ + Spawn child workers used by this actor + + No-op when should_spawn_workers() is False. 
+ """ + pass + class BaseTokenizer(ABC): """ From cf520aada33b0999c1efd34a2aba7679fd6deb7b Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 21 Aug 2025 13:59:02 -0700 Subject: [PATCH 12/15] Commenting sections to update post Service Refactor --- src/forge/actors/policy.py | 6 ++---- src/forge/interfaces.py | 24 +++++++++++++----------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 74bb05b72..af5705c16 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -24,7 +24,6 @@ from monarch.actor import Actor, current_rank, endpoint, proc_mesh from omegaconf import DictConfig, OmegaConf from torchstore import MultiProcessStore - from torchstore._state_dict_utils import DELIM from vllm.engine.arg_utils import EngineArgs @@ -441,9 +440,8 @@ async def _test(config: DictConfig): with_service = False if with_service: - service_config = ServiceConfig( - procs_per_replica=1, min_replicas=1, max_replicas=2, default_replicas=1 - ) + # Update this condition path once Service has been refactored + service_config = ServiceConfig(procs_per_replica=1, num_replicas=1) print("spawning service") service = await spawn_service(service_config, Policy, config=config) service.run() diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index fd8da69ad..56f77ad44 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -87,19 +87,21 @@ async def update_weights(self): """Update the policy weights.""" pass - @abstractmethod - def should_spawn_workers(self) -> bool: - """Whether the policy needs to separately spawrn child workers.""" - pass + # TODO: Update Based on Service Refactor - @abstractmethod - def spawn_workers(self): - """ - Spawn child workers used by this actor + # @abstractmethod + # def should_spawn_workers(self) -> bool: + # """Whether the policy needs to separately spawrn child workers.""" + # pass - No-op when should_spawn_workers() is False. - """ - pass + # @abstractmethod + # def spawn_workers(self): + # """ + # Spawn child workers used by this actor + + # No-op when should_spawn_workers() is False. 
+ # """ + # pass class BaseTokenizer(ABC): From 47f3a1d5859cea53fcd0c12f3318d239fe19b292 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Wed, 20 Aug 2025 11:38:47 -0700 Subject: [PATCH 13/15] Pushing Policy Worker:rough --- src/forge/actors/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index af5705c16..0da83bdc1 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -18,9 +18,9 @@ from forge.controller.spawn import spawn_service from forge.data.sharding import VLLMSharding + from forge.interfaces import Policy as PolicyInterface from forge.types import ProcessConfig - from monarch.actor import Actor, current_rank, endpoint, proc_mesh from omegaconf import DictConfig, OmegaConf from torchstore import MultiProcessStore From b96cbfe59c28dbaa6bbbed48ecebc9228cd2c312 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Mon, 25 Aug 2025 15:08:02 -0700 Subject: [PATCH 14/15] Partial paths for sapwn_actor and spawn service --- src/forge/actors/policy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 0da83bdc1..150f57ebf 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -18,7 +18,6 @@ from forge.controller.spawn import spawn_service from forge.data.sharding import VLLMSharding - from forge.interfaces import Policy as PolicyInterface from forge.types import ProcessConfig from monarch.actor import Actor, current_rank, endpoint, proc_mesh From 4d3700783a9447e8798a1a92112cfa0e3823b3d6 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Mon, 25 Aug 2025 16:19:30 -0700 Subject: [PATCH 15/15] Debugging Service Path --- src/forge/actors/policy.py | 26 +++++++++++++++++--------- src/forge/interfaces.py | 16 ---------------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 150f57ebf..1301f4113 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -111,9 +111,6 @@ async def setup(self): log_stats=None, ) - def should_spawn_workers(self) -> bool: - return True - async def spawn_workers(self): self.worker_mesh = await proc_mesh( gpus=self.config.num_workers, @@ -186,6 +183,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_fut = asyncio.Future() self.requests[request_id] = (parent_req, request_fut) + print("Awaiting in Generate") return await request_fut # Abstracted to match vllm @@ -200,9 +198,10 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i return request, 0 # Unused Arg: Current Wave @endpoint - async def run(self): + async def run_2(self): # TODO: add support for `iteration_stats` # TODO: move postprocessing out of loop to not block + print("Policy Run: starting") parallel_config = self.vllm_args.parallel_config output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size self.running = True @@ -292,6 +291,7 @@ async def setup(self, store: MultiProcessStore = None): # TODO: remove ["gpus"] when monarch implements a flat rank self.rank = current_rank()["gpus"] self.worker = self.setup_worker() + print("PolicyWorker setup complete, with rank: ", self.rank) @endpoint async def execute_model(self, schedule: SchedulerOutput): @@ -437,14 +437,22 @@ async def _test(config: DictConfig): "What is 3+5?" 
if config.sampling_params.guided_decoding else "Tell me a joke" ) - with_service = False + with_service = True if with_service: # Update this condition path once Service has been refactored service_config = ServiceConfig(procs_per_replica=1, num_replicas=1) - print("spawning service") - service = await spawn_service(service_config, Policy, config=config) - service.run() - responses: List[CompletionOutput] = await service.generate(prompt) + print("Spawning service") + + # ServiceInterface is a wrapper around the Policy + policy = await spawn_service(service_config, Policy, config=config) + session_id = await policy.start_session() + + print("Kick off background processing") + policy.run_2.choose() + + print("Request Generation") + responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt) + await policy.terminate_session(session_id) else: process_config = ProcessConfig() diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 56f77ad44..b485fc791 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -87,22 +87,6 @@ async def update_weights(self): """Update the policy weights.""" pass - # TODO: Update Based on Service Refactor - - # @abstractmethod - # def should_spawn_workers(self) -> bool: - # """Whether the policy needs to separately spawrn child workers.""" - # pass - - # @abstractmethod - # def spawn_workers(self): - # """ - # Spawn child workers used by this actor - - # No-op when should_spawn_workers() is False. - # """ - # pass - class BaseTokenizer(ABC): """
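Taking the series as a whole, the `with_service=True` path in `_test` is the shape a caller is expected to use: spawn the `Policy` behind a service, start the background scheduler loop (`run_2`), and generate through the endpoint sugar. A condensed sketch of that driver flow, mirroring the test and assuming the `forge` modules and config layout from this series (it will not run outside the repo):

```python
import asyncio

from omegaconf import OmegaConf

from forge.actors.policy import Policy
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service


async def drive_policy_service(config) -> None:
    # One replica with one process, as in the PATCH 15 test path.
    policy = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1), Policy, config=config
    )
    session_id = await policy.start_session()

    # Fire-and-forget, exactly as the test does: run_2 drives the vLLM
    # scheduler loop in the background while generate awaits its future.
    policy.run_2.choose()

    prompt = (
        "What is 3+5?"
        if config.sampling_params.guided_decoding
        else "Tell me a joke"
    )
    responses = await policy.generate.choose(prompt=prompt)
    for response in responses:
        print(f"User: {prompt}\nAssistant: {response.text}")

    await policy.terminate_session(session_id)
    await policy.stop()


if __name__ == "__main__":
    cfg = OmegaConf.create(
        {
            "num_workers": 2,
            "worker_params": {
                "model": "meta-llama/Llama-3.1-8B-Instruct",
                "tensor_parallel_size": 2,
                "pipeline_parallel_size": 1,
                "enforce_eager": True,
                "vllm_args": None,
            },
            "sampling_params": {"guided_decoding": True, "num_samples": 2},
        }
    )
    asyncio.run(drive_policy_service(cfg))
```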