diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 6866a6949..20f0744bf 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -9,14 +9,20 @@
 import os
 import sys
 from copy import copy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Dict, List
 
 import torch
+from forge.controller import spawn_actors
+from forge.controller.service import ServiceConfig
+from forge.controller.spawn import spawn_service
+from forge.data.sharding import VLLMSharding
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
 from monarch.actor import Actor, current_rank, endpoint, proc_mesh
+from omegaconf import DictConfig, OmegaConf
 from torchstore import MultiProcessStore
-
 from torchstore._state_dict_utils import DELIM
 
 from vllm.engine.arg_utils import EngineArgs
 
@@ -39,24 +45,87 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.data.sharding import VLLMSharding
-
 logger = logging.getLogger(__name__)
 
 
 @dataclass
-class PolicyRouter(Actor):
-    # TODO: Add dp support
-    policy: Actor
+class SamplingOverrides:
+    """
+    Overrides for vLLM's sampling params.
+
+    Note: We'll want to tie this closer to, or directly use, vLLM's
+    SamplingParams. It is currently used to track a supported
+    subset.
+
+    Args:
+        num_samples: Number of samples to generate.
+        guided_decoding: Whether to use guided decoding.
+    """
+
+    num_samples: int
+    guided_decoding: bool = False
+
+
+@dataclass
+class WorkerConfig:
+    """
+    Config args used for setting up the policy worker.
+
+    Args:
+        model: Model name.
+        tensor_parallel_size: Number of tensor parallel workers.
+        pipeline_parallel_size: Number of pipeline parallel workers.
+        enforce_eager: Whether to enforce eager mode.
+        vllm_args: vLLM engine args.
+ """ + + model: str + tensor_parallel_size: int = 1 + pipeline_parallel_size: int = 1 + enforce_eager: bool = False + vllm_args: EngineArgs = None + + +@dataclass +class PolicyConfig: + num_workers: int + worker_params: WorkerConfig + sampling_params: SamplingOverrides + + +@dataclass +class Policy(PolicyInterface): + config: PolicyConfig + # Gets set up by setup + policy_worker: Actor = None + sampling_params: SamplingParams = None lora_request: LoRARequest = None tokenization_kwargs: dict = None @endpoint async def setup(self): + # Set up policy_worker + await self.spawn_workers() + self.request_id = 0 self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {} - self.vllm_args = await self.policy.get_vllm_args.choose() + self.vllm_args = await self.policy_worker.get_vllm_args.choose() + + # Setup sampling params + sampling_overrides = self.config.sampling_params + overrides = { + "n": sampling_overrides.num_samples, + "guided_decoding": ( + GuidedDecodingParams(choice=["Positive", "Negative"]) + if sampling_overrides.guided_decoding + else None + ), + } + self.sampling_params = get_default_sampling_params( + self.vllm_args, overrides=overrides + ) + # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` @@ -70,9 +139,9 @@ async def setup(self): ) self.output_processor = OutputProcessor(tokenizer, log_stats=None) - # Setup schduuler + # Setup scheduler # TODO: Add support for `log_stats` - kv_cache_configs = await self.policy.setup_kv_cache.call() + kv_cache_configs = await self.policy_worker.setup_kv_cache.call() kv_cache_config = kv_cache_configs._values[0] self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks self.vllm_args.cache_config.num_cpu_blocks = 0 @@ -86,6 +155,19 @@ async def setup(self): log_stats=None, ) + async def spawn_workers(self): + self.worker_mesh = await proc_mesh( + gpus=self.config.num_workers, + env={ + "MASTER_ADDR": str(get_loopback_ip()), + "MASTER_PORT": str(get_open_port()), + }, + ) + self.policy_worker = await self.worker_mesh.spawn( + "policy_worker", PolicyWorker, **asdict(self.config.worker_params) + ) + await self.policy_worker.setup.call() + @endpoint async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]: self.request_id += 1 % sys.maxsize @@ -93,8 +175,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu # Wraps prompt into a dict prompt: Dict[str, str] = convert_input(prompt) - if self.sampling_params is None: - self.sampling_params = get_default_sampling_params(self.vllm_args) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -161,7 +241,7 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i return request, 0 # Unused Arg: Current Wave @endpoint - async def run(self): + async def run_processing(self): # TODO: add support for `iteration_stats` # TODO: move postprocessing out of loop to not block parallel_config = self.vllm_args.parallel_config @@ -169,7 +249,9 @@ async def run(self): self.running = True while self.running: scheduler_output = self.scheduler.schedule() - worker_outputs = await self.policy.execute_model.call(scheduler_output) + worker_outputs = await self.policy_worker.execute_model.call( + scheduler_output + ) worker_output = worker_outputs._values[output_rank] outputs = self.scheduler.update_from_output(scheduler_output, worker_output) outputs = outputs.get(0) or EngineCoreOutputs() @@ -185,19 +267,23 @@ async def 
                 _, fut = self.requests.pop(request_output.request_id)
                 fut.set_result(request_output.outputs)
 
+    @endpoint
+    async def update_weights(self):
+        """Update the policy weights."""
+        pass
+
     @endpoint
     async def shutdown(self):
         self.running = False
 
 
 @dataclass
-class Policy(Actor):
+class PolicyWorker(Actor):
     model: str
     tensor_parallel_size: int = 1
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
     vllm_args: EngineArgs = None
-    resources: int = 1
     state_dict_key: str = "model_state_dict"
 
     def __post_init__(self):
@@ -240,7 +326,6 @@ def __post_init__(self):
             setattr(self.vllm_args, key, value)
         # Build Config
         self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
-        assert self.vllm_args.parallel_config.world_size == self.resources
 
     @endpoint
     async def setup(self, store: MultiProcessStore = None):
@@ -387,48 +472,26 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
     return params
 
 
-async def _test(config, guided_decoding=False, num_samples=1):
-    # TODO: Create proper test
-    router_mesh = await proc_mesh(gpus=1)
-    policy_mesh = await proc_mesh(
-        gpus=config["resources"],
-        env={
-            "MASTER_ADDR": str(get_loopback_ip()),
-            "MASTER_PORT": str(get_open_port()),
-        },
+# TODO: Create proper test
+async def _test(config: PolicyConfig):
+    prompt = (
+        "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
     )
+    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
 
-    policy_actor = await policy_mesh.spawn("policy", Policy, **config)
-
-    # TODO: Make this customizable from the config
-    overrides = {
-        "n": num_samples,
-        "guided_decoding": (
-            GuidedDecodingParams(choice=["Positive", "Negative"])
-            if guided_decoding
-            else None
-        ),
-    }
-
-    vllm_args = await policy_actor.get_vllm_args.choose()
-    sampling_params = get_default_sampling_params(vllm_args, overrides=overrides)
-
-    router = await router_mesh.spawn(
-        "policy_router",
-        PolicyRouter,
-        policy=policy_actor,
-        sampling_params=sampling_params,
-    )
+    print("Spawning service")
+    policy = await spawn_service(service_config, Policy, config=config)
+    session_id = await policy.start_session()
 
-    await policy_actor.setup.call()
-    await router.setup.call()
-    print("Model setup")
+    print("Kick off background processing")
+    asyncio.create_task(policy.run_processing.call())
 
-    router.run.call()
-    print("Model running")
+    print("Request Generation")
+    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
 
-    prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
-    responses: List[CompletionOutput] = await router.generate.call_one(prompt)
+    print("Terminating session")
+    await policy.shutdown.call()
+    await policy.terminate_session(session_id)
 
     for batch, response in enumerate(responses):
         print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
@@ -436,18 +499,20 @@ async def _test(config, guided_decoding=False, num_samples=1):
         print(f"User: {prompt}\nAssistant: {response.text}")
         print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
 
-    await router.shutdown.call()
-
 
 if __name__ == "__main__":
-    config = {
-        "model": "meta-llama/Llama-3.1-8B-Instruct",
-        "tensor_parallel_size": 2,
-        "pipeline_parallel_size": 1,
-        "enforce_eager": True,
-        "resources": 2,
-    }
-    # asyncio.run(_test(config))
-    # asyncio.run(_test(config, guided_decoding=True))
-    # asyncio.run(_test(config, num_samples=2))
-    # asyncio.run(_test(config, guided_decoding=True, num_samples=3))
+    config = PolicyConfig(
+        num_workers=2,
+        worker_params=WorkerConfig(
+            model="meta-llama/Llama-3.1-8B-Instruct",
+            tensor_parallel_size=2,
+            pipeline_parallel_size=1,
+            enforce_eager=True,
+            vllm_args=None,
+        ),
+        sampling_params=SamplingOverrides(
+            num_samples=2,
+            guided_decoding=True,
+        ),
+    )
+    asyncio.run(_test(config))
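
Usage sketch (illustrative; not part of the patch): after this change, callers no longer spawn `PolicyRouter` and `Policy` meshes by hand. They build a `PolicyConfig` and talk to the `Policy` service, which spawns and manages its own `PolicyWorker` mesh. The minimal driver below mirrors `_test`, assuming these classes are importable from `forge.actors.policy` and that `generate.choose` routes a request to one replica while `shutdown.call` broadcasts, as in the diff:

    import asyncio

    from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
    from forge.controller.service import ServiceConfig
    from forge.controller.spawn import spawn_service


    async def main() -> None:
        config = PolicyConfig(
            # Worker mesh size; kept equal to tensor_parallel_size, as in __main__ above.
            num_workers=1,
            worker_params=WorkerConfig(
                model="meta-llama/Llama-3.1-8B-Instruct",
                tensor_parallel_size=1,
                enforce_eager=True,
            ),
            sampling_params=SamplingOverrides(num_samples=1),
        )

        # One single-proc replica is enough for a smoke test.
        policy = await spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1), Policy, config=config
        )
        session_id = await policy.start_session()

        # The scheduler loop must be running in the background before
        # generate() futures can resolve.
        asyncio.create_task(policy.run_processing.call())

        responses = await policy.generate.choose(prompt="Tell me a joke")
        for response in responses:
            print(response.text)

        await policy.shutdown.call()
        await policy.terminate_session(session_id)


    if __name__ == "__main__":
        asyncio.run(main())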