From fcc798a6f49d19c54263512f5f59af7bbd642a34 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Mon, 25 Aug 2025 16:47:59 -0700
Subject: [PATCH 1/3] [Debugging] Push PolicyWorker into Router and leverage
 Services

---
 src/forge/actors/policy.py | 172 +++++++++++++++++++++++--------------
 1 file changed, 107 insertions(+), 65 deletions(-)

diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 6866a6949..1329acac3 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -13,10 +13,16 @@
 from typing import Dict, List
 
 import torch
 
+from forge.controller import spawn_actors
+from forge.controller.service import ServiceConfig
+from forge.controller.spawn import spawn_service
+from forge.data.sharding import VLLMSharding
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
 from monarch.actor import Actor, current_rank, endpoint, proc_mesh
+from omegaconf import DictConfig, OmegaConf
 from torchstore import MultiProcessStore
-
 from torchstore._state_dict_utils import DELIM
 
 from vllm.engine.arg_utils import EngineArgs
@@ -39,24 +45,43 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.data.sharding import VLLMSharding
-
 logger = logging.getLogger(__name__)
 
 
 @dataclass
-class PolicyRouter(Actor):
+class Policy(PolicyInterface):
     # TODO: Add dp support
-    policy: Actor
+    config: DictConfig
+    # Gets set up by setup worker
+    policy_worker: Actor = None
+
     sampling_params: SamplingParams = None
     lora_request: LoRARequest = None
     tokenization_kwargs: dict = None
 
     @endpoint
     async def setup(self):
+        # Set up policy_worker
+        await self.spawn_workers()
+
         self.request_id = 0
         self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {}
-        self.vllm_args = await self.policy.get_vllm_args.choose()
+        self.vllm_args = await self.policy_worker.get_vllm_args.choose()
+
+        # Setup sampling params
+        sampling_overrides = self.config.sampling_params
+        overrides = {
+            "n": sampling_overrides.num_samples,
+            "guided_decoding": (
+                GuidedDecodingParams(choice=["Positive", "Negative"])
+                if sampling_overrides.guided_decoding
+                else None
+            ),
+        }
+        self.sampling_params = get_default_sampling_params(
+            self.vllm_args, overrides=overrides
+        )
+
         # Setup processors
         # TODO: move all processing to the Environment
         # TODO: add support for `log_stats` and `mm_registry`
@@ -70,9 +95,9 @@ async def setup(self):
         )
         self.output_processor = OutputProcessor(tokenizer, log_stats=None)
 
-        # Setup schduuler
+        # Setup scheduler
         # TODO: Add support for `log_stats`
-        kv_cache_configs = await self.policy.setup_kv_cache.call()
+        kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
         kv_cache_config = kv_cache_configs._values[0]
         self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
         self.vllm_args.cache_config.num_cpu_blocks = 0
@@ -86,6 +111,19 @@ async def setup(self):
             log_stats=None,
         )
 
+    async def spawn_workers(self):
+        self.worker_mesh = await proc_mesh(
+            gpus=self.config.num_workers,
+            env={
+                "MASTER_ADDR": str(get_loopback_ip()),
+                "MASTER_PORT": str(get_open_port()),
+            },
+        )
+        self.policy_worker = await self.worker_mesh.spawn(
+            "policy_worker", PolicyWorker, **self.config.worker_params
+        )
+        await self.policy_worker.setup.call()
+
     @endpoint
     async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
         self.request_id += 1 % sys.maxsize
@@ -93,8 +131,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
 
         # Wraps prompt into a dict
         prompt: Dict[str, str] = convert_input(prompt)
 
-        if self.sampling_params is None:
-            self.sampling_params = get_default_sampling_params(self.vllm_args)
         # truncate prompt
         tokenization_kwargs = self.tokenization_kwargs or {}
@@ -147,6 +183,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
 
         request_fut = asyncio.Future()
         self.requests[request_id] = (parent_req, request_fut)
+        print("Awaiting in Generate")
         return await request_fut
 
     # Abstracted to match vllm
@@ -161,15 +198,18 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
         return request, 0  # Unused Arg: Current Wave
 
     @endpoint
-    async def run(self):
+    async def run_processing(self):
         # TODO: add support for `iteration_stats`
         # TODO: move postprocessing out of loop to not block
+        print("Policy Run: starting")
         parallel_config = self.vllm_args.parallel_config
         output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
         self.running = True
         while self.running:
             scheduler_output = self.scheduler.schedule()
-            worker_outputs = await self.policy.execute_model.call(scheduler_output)
+            worker_outputs = await self.policy_worker.execute_model.call(
+                scheduler_output
+            )
             worker_output = worker_outputs._values[output_rank]
             outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
             outputs = outputs.get(0) or EngineCoreOutputs()
@@ -185,19 +225,23 @@ async def run(self):
             _, fut = self.requests.pop(request_output.request_id)
             fut.set_result(request_output.outputs)
 
+    @endpoint
+    async def update_weights(self):
+        """Update the policy weights."""
+        pass
+
     @endpoint
     async def shutdown(self):
         self.running = False
 
 
 @dataclass
-class Policy(Actor):
+class PolicyWorker(Actor):
     model: str
     tensor_parallel_size: int = 1
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
     vllm_args: EngineArgs = None
-    resources: int = 1
     state_dict_key: str = "model_state_dict"
 
     def __post_init__(self):
@@ -240,7 +284,6 @@ def __post_init__(self):
             setattr(self.vllm_args, key, value)
         # Build Config
         self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
-        assert self.vllm_args.parallel_config.world_size == self.resources
 
     @endpoint
     async def setup(self, store: MultiProcessStore = None):
@@ -387,48 +430,43 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
 
     return params
 
 
-async def _test(config, guided_decoding=False, num_samples=1):
-    # TODO: Create proper test
-    router_mesh = await proc_mesh(gpus=1)
-    policy_mesh = await proc_mesh(
-        gpus=config["resources"],
-        env={
-            "MASTER_ADDR": str(get_loopback_ip()),
-            "MASTER_PORT": str(get_open_port()),
-        },
+# TODO: Create proper test
+async def _test(config: DictConfig):
+    prompt = (
+        "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
     )
-    policy_actor = await policy_mesh.spawn("policy", Policy, **config)
-
-    # TODO: Make this customizable from the config
-    overrides = {
-        "n": num_samples,
-        "guided_decoding": (
-            GuidedDecodingParams(choice=["Positive", "Negative"])
-            if guided_decoding
-            else None
-        ),
-    }
-
-    vllm_args = await policy_actor.get_vllm_args.choose()
-    sampling_params = get_default_sampling_params(vllm_args, overrides=overrides)
-
-    router = await router_mesh.spawn(
-        "policy_router",
-        PolicyRouter,
-        policy=policy_actor,
-        sampling_params=sampling_params,
-    )
+    with_service = True
+    if with_service:
+        # Update this condition path once Service has been refactored
+        service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
+        print("Spawning service")
+
+        # ServiceInterface is a wrapper around the Policy
+        policy = await spawn_service(service_config, Policy, config=config)
+        session_id = await policy.start_session()
 
-    await policy_actor.setup.call()
-    await router.setup.call()
-    print("Model setup")
+        print("Kick off background processing")
+        policy.run_processing.choose()
 
-    router.run.call()
-    print("Model running")
+        print("Request Generation")
+        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
 
-    prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
-    responses: List[CompletionOutput] = await router.generate.call_one(prompt)
+        print("Terminating session")
+        await policy.shutdown.call()
+        await policy.terminate_session(session_id)
+
+    else:
+        process_config = ProcessConfig()
+        policy = await spawn_actors(
+            name="policy", actor_cls=Policy, cfg=config, processes=process_config
+        )
+        print("Model setup")
+        await policy.setup.call()
+        print("Model running")
+        policy.run_processing.call()
+        responses: List[CompletionOutput] = await policy.generate.call_one(prompt)
+        await policy.shutdown.call()
 
     for batch, response in enumerate(responses):
         print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
@@ -436,18 +474,22 @@ async def _test(config, guided_decoding=False, num_samples=1):
         print(f"User: {prompt}\nAssistant: {response.text}")
         print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
 
-    await router.shutdown.call()
-
 
 if __name__ == "__main__":
-    config = {
-        "model": "meta-llama/Llama-3.1-8B-Instruct",
-        "tensor_parallel_size": 2,
-        "pipeline_parallel_size": 1,
-        "enforce_eager": True,
-        "resources": 2,
-    }
-    # asyncio.run(_test(config))
-    # asyncio.run(_test(config, guided_decoding=True))
-    # asyncio.run(_test(config, num_samples=2))
-    # asyncio.run(_test(config, guided_decoding=True, num_samples=3))
+    config = OmegaConf.create(
+        {
+            "num_workers": 2,
+            "worker_params": {
+                "model": "meta-llama/Llama-3.1-8B-Instruct",
+                "tensor_parallel_size": 2,
+                "pipeline_parallel_size": 1,
+                "enforce_eager": True,
+                "vllm_args": None,
+            },
+            "sampling_params": {
+                "guided_decoding": True,
+                "num_samples": 2,
+            },
+        }
+    )
+    asyncio.run(_test(config))

From 865cf9b9ad2aeda0a8a6437004ca3bd5dc3ba9c4 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Mon, 25 Aug 2025 17:40:13 -0700
Subject: [PATCH 2/3] Remove prints and add async call

---
 src/forge/actors/policy.py | 41 ++++++++++----------------------------
 1 file changed, 11 insertions(+), 30 deletions(-)

diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 1329acac3..9a8eeac2a 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -183,7 +183,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
 
         request_fut = asyncio.Future()
         self.requests[request_id] = (parent_req, request_fut)
-        print("Awaiting in Generate")
         return await request_fut
 
     # Abstracted to match vllm
@@ -201,7 +200,6 @@ async def run_processing(self):
         # TODO: add support for `iteration_stats`
         # TODO: move postprocessing out of loop to not block
-        print("Policy Run: starting")
         parallel_config = self.vllm_args.parallel_config
         output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
         self.running = True
         while self.running:
@@ -435,38 +433,21 @@ async def _test(config: DictConfig):
     prompt = (
         "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
     )
+    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
 
-    with_service = True
-    if with_service:
-        # Update this condition path once Service has been refactored
-        service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
-        print("Spawning service")
+    print("Spawning service")
+    policy = await spawn_service(service_config, Policy, config=config)
+    session_id = await policy.start_session()
 
-        # ServiceInterface is a wrapper around the Policy
-        policy = await spawn_service(service_config, Policy, config=config)
-        session_id = await policy.start_session()
+    print("Kick off background processing")
+    asyncio.create_task(policy.run_processing.call())
 
-        print("Kick off background processing")
-        policy.run_processing.choose()
+    print("Request Generation")
+    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
 
-        print("Request Generation")
-        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
-
-        print("Terminating session")
-        await policy.shutdown.call()
-        await policy.terminate_session(session_id)
-
-    else:
-        process_config = ProcessConfig()
-        policy = await spawn_actors(
-            name="policy", actor_cls=Policy, cfg=config, processes=process_config
-        )
-        print("Model setup")
-        await policy.setup.call()
-        print("Model running")
-        policy.run_processing.call()
-        responses: List[CompletionOutput] = await policy.generate.call_one(prompt)
-        await policy.shutdown.call()
+    print("Terminating session")
+    await policy.shutdown.call()
+    await policy.terminate_session(session_id)
 
     for batch, response in enumerate(responses):
         print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

From ed0508f6ccdd809a022c57b42bb10654cad7c50c Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Tue, 26 Aug 2025 10:34:46 -0700
Subject: [PATCH 3/3] Explicit Config classes

---
 src/forge/actors/policy.py | 82 ++++++++++++++++++++++++++++----------
 1 file changed, 62 insertions(+), 20 deletions(-)

diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 9a8eeac2a..20f0744bf 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -9,7 +9,7 @@
 import os
 import sys
 from copy import copy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Dict, List
 
 import torch
@@ -48,11 +48,55 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class SamplingOverrides:
+    """
+    Overrides for vLLM's sampling params.
+
+    Note: We'll want to tie this closer to or directly use vllm's
+    SamplingParams. It is currently used to track a supported
+    subset.
+
+    Args:
+        num_samples: Number of samples to generate.
+        guided_decoding: Whether to use guided decoding.
+    """
+
+    num_samples: int
+    guided_decoding: bool = False
+
+
+@dataclass
+class WorkerConfig:
+    """
+    Config args used for setting up the policy worker.
+
+    Args:
+        model: Model name.
+        tensor_parallel_size: Number of tensor parallel workers.
+        pipeline_parallel_size: Number of pipeline parallel workers.
+        enforce_eager: Whether to enforce eager mode.
+        vllm_args: vLLM engine args.
+    """
+
+    model: str
+    tensor_parallel_size: int = 1
+    pipeline_parallel_size: int = 1
+    enforce_eager: bool = False
+    vllm_args: EngineArgs = None
+
+
+@dataclass
+class PolicyConfig:
+    num_workers: int
+    worker_params: WorkerConfig
+    sampling_params: SamplingOverrides
+
+
 @dataclass
 class Policy(PolicyInterface):
-    # TODO: Add dp support
-    config: DictConfig
-    # Gets set up by setup worker
+    config: PolicyConfig
+    # Gets set up by setup
     policy_worker: Actor = None
 
     sampling_params: SamplingParams = None
@@ -120,7 +164,7 @@ async def spawn_workers(self):
             },
         )
         self.policy_worker = await self.worker_mesh.spawn(
-            "policy_worker", PolicyWorker, **self.config.worker_params
+            "policy_worker", PolicyWorker, **asdict(self.config.worker_params)
         )
         await self.policy_worker.setup.call()
 
@@ -457,20 +501,18 @@
 
 
 if __name__ == "__main__":
-    config = OmegaConf.create(
-        {
-            "num_workers": 2,
-            "worker_params": {
-                "model": "meta-llama/Llama-3.1-8B-Instruct",
-                "tensor_parallel_size": 2,
-                "pipeline_parallel_size": 1,
-                "enforce_eager": True,
-                "vllm_args": None,
-            },
-            "sampling_params": {
-                "guided_decoding": True,
-                "num_samples": 2,
-            },
-        }
+    config = PolicyConfig(
+        num_workers=2,
+        worker_params=WorkerConfig(
+            model="meta-llama/Llama-3.1-8B-Instruct",
+            tensor_parallel_size=2,
+            pipeline_parallel_size=1,
+            enforce_eager=True,
+            vllm_args=None,
+        ),
+        sampling_params=SamplingOverrides(
+            num_samples=2,
+            guided_decoding=True,
+        ),
     )
     asyncio.run(_test(config))
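
Note: below is a minimal, self-contained sketch of driving the API this series lands on, mirroring what `_test` does after PATCH 3. It is illustrative only, not part of the diff: the import path `forge.actors.policy` is an assumption inferred from `src/forge/actors/policy.py`, and it assumes the handle returned by `spawn_service` exposes the endpoints exactly as `_test` uses them (`start_session`, `run_processing`, `generate`, `shutdown`, `terminate_session`).

import asyncio

# Assumed import path, inferred from src/forge/actors/policy.py.
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service


async def main() -> None:
    # Explicit config classes introduced in PATCH 3.
    config = PolicyConfig(
        num_workers=2,
        worker_params=WorkerConfig(
            model="meta-llama/Llama-3.1-8B-Instruct",
            tensor_parallel_size=2,
            pipeline_parallel_size=1,
            enforce_eager=True,
            vllm_args=None,
        ),
        sampling_params=SamplingOverrides(num_samples=2, guided_decoding=True),
    )

    # One replica with one proc per replica, as in _test.
    policy = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1), Policy, config=config
    )
    session_id = await policy.start_session()

    # PATCH 2 made the scheduler loop a background task; generate()'s future
    # is only resolved while run_processing() is draining requests.
    processing = asyncio.create_task(policy.run_processing.call())

    responses = await policy.generate.choose(prompt="What is 3+5?")
    for response in responses:
        print(response.text)

    # shutdown() clears self.running, which lets the processing loop exit;
    # awaiting the task here assumes the loop observes the flag promptly.
    await policy.shutdown.call()
    await processing
    await policy.terminate_session(session_id)


if __name__ == "__main__":
    asyncio.run(main())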