199 changes: 132 additions & 67 deletions src/forge/actors/policy.py
@@ -9,14 +9,20 @@
import os
import sys
from copy import copy
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Dict, List

import torch
from forge.controller import spawn_actors
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from omegaconf import DictConfig, OmegaConf
from torchstore import MultiProcessStore

from torchstore._state_dict_utils import DELIM

from vllm.engine.arg_utils import EngineArgs
@@ -39,24 +45,87 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.data.sharding import VLLMSharding

logger = logging.getLogger(__name__)


@dataclass
class PolicyRouter(Actor):
# TODO: Add dp support
policy: Actor
class SamplingOverrides:
"""
Overrides for vLLM's sampling params.

Note: We'll want to tie this closer to, or directly use, vLLM's
SamplingParams. This currently tracks only a supported subset of
the sampling options.

Args:
num_samples: Number of samples to generate.
guided_decoding: Whether to use guided decoding.
"""

num_samples: int
guided_decoding: bool = False
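
For illustration, a minimal sketch of how these overrides are consumed; it mirrors the mapping applied in Policy.setup below, the concrete values and the GuidedDecodingParams import path are assumptions:

```python
from vllm.sampling_params import GuidedDecodingParams  # assumed import path

# Hypothetical config: 4 samples per prompt, guided decoding enabled.
sampling = SamplingOverrides(num_samples=4, guided_decoding=True)

# Policy.setup translates the overrides into vLLM SamplingParams kwargs.
overrides = {
    "n": sampling.num_samples,
    "guided_decoding": (
        GuidedDecodingParams(choice=["Positive", "Negative"])
        if sampling.guided_decoding
        else None
    ),
}
```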


@dataclass
class WorkerConfig:
"""
Config args used for setting up the policy worker.

Args:
model: Model name.
tensor_parallel_size: Number of tensor parallel workers.
pipeline_parallel_size: Number of pipeline parallel workers.
enforce_eager: Whether to enforce eager mode.
vllm_args: vLLM engine args.
"""

model: str
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None


@dataclass
class PolicyConfig:
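"""
Top-level config for the Policy service.

Args:
    num_workers: Number of PolicyWorker processes (GPUs) to spawn.
    worker_params: Config passed through to each PolicyWorker.
    sampling_params: Overrides applied on top of vLLM's default sampling params.
"""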
num_workers: int
worker_params: WorkerConfig
sampling_params: SamplingOverrides


@dataclass
class Policy(PolicyInterface):
config: PolicyConfig
# Gets set up by setup
policy_worker: Actor = None

sampling_params: SamplingParams = None
lora_request: LoRARequest = None
tokenization_kwargs: dict = None

@endpoint
async def setup(self):
# Set up policy_worker
await self.spawn_workers()

self.request_id = 0
self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_args = await self.policy.get_vllm_args.choose()
self.vllm_args = await self.policy_worker.get_vllm_args.choose()

# Setup sampling params
sampling_overrides = self.config.sampling_params
overrides = {
"n": sampling_overrides.num_samples,
"guided_decoding": (
GuidedDecodingParams(choice=["Positive", "Negative"])
if sampling_overrides.guided_decoding
else None
),
}
self.sampling_params = get_default_sampling_params(
self.vllm_args, overrides=overrides
)

# Setup processors
# TODO: move all processing to the Environment
# TODO: add support for `log_stats` and `mm_registry`
@@ -70,9 +139,9 @@ async def setup(self):
)
self.output_processor = OutputProcessor(tokenizer, log_stats=None)

# Setup schduuler
# Setup scheduler
# TODO: Add support for `log_stats`
kv_cache_configs = await self.policy.setup_kv_cache.call()
kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
kv_cache_config = kv_cache_configs._values[0]
self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_args.cache_config.num_cpu_blocks = 0
@@ -86,15 +155,26 @@
log_stats=None,
)

async def spawn_workers(self):
self.worker_mesh = await proc_mesh(
gpus=self.config.num_workers,
env={
"MASTER_ADDR": str(get_loopback_ip()),
"MASTER_PORT": str(get_open_port()),
},
)
self.policy_worker = await self.worker_mesh.spawn(
"policy_worker", PolicyWorker, **asdict(self.config.worker_params)
)
await self.policy_worker.setup.call()

@endpoint
async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
self.request_id += 1 % sys.maxsize
request_id = str(self.request_id) # implement from a counter

# Wraps prompt into a dict
prompt: Dict[str, str] = convert_input(prompt)
if self.sampling_params is None:
self.sampling_params = get_default_sampling_params(self.vllm_args)

# truncate prompt
tokenization_kwargs = self.tokenization_kwargs or {}
@@ -161,15 +241,17 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
return request, 0 # Unused Arg: Current Wave

@endpoint
async def run(self):
async def run_processing(self):
# TODO: add support for `iteration_stats`
# TODO: move postprocessing out of loop to not block
parallel_config = self.vllm_args.parallel_config
output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
self.running = True
while self.running:
scheduler_output = self.scheduler.schedule()
worker_outputs = await self.policy.execute_model.call(scheduler_output)
worker_outputs = await self.policy_worker.execute_model.call(
scheduler_output
)
worker_output = worker_outputs._values[output_rank]
outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
outputs = outputs.get(0) or EngineCoreOutputs()
@@ -185,19 +267,23 @@ async def run(self):
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output.outputs)

@endpoint
async def update_weights(self):
"""Update the policy weights."""
pass

@endpoint
async def shutdown(self):
self.running = False


@dataclass
class Policy(Actor):
class PolicyWorker(Actor):
model: str
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None
resources: int = 1
state_dict_key: str = "model_state_dict"

def __post_init__(self):
@@ -240,7 +326,6 @@ def __post_init__(self):
setattr(self.vllm_args, key, value)
# Build Config
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
assert self.vllm_args.parallel_config.world_size == self.resources

@endpoint
async def setup(self, store: MultiProcessStore = None):
@@ -387,67 +472,47 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
return params


async def _test(config, guided_decoding=False, num_samples=1):
# TODO: Create proper test
router_mesh = await proc_mesh(gpus=1)
policy_mesh = await proc_mesh(
gpus=config["resources"],
env={
"MASTER_ADDR": str(get_loopback_ip()),
"MASTER_PORT": str(get_open_port()),
},
# TODO: Create proper test
async def _test(config: DictConfig):
prompt = (
"What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
)
service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

policy_actor = await policy_mesh.spawn("policy", Policy, **config)

# TODO: Make this customizable from the config
overrides = {
"n": num_samples,
"guided_decoding": (
GuidedDecodingParams(choice=["Positive", "Negative"])
if guided_decoding
else None
),
}

vllm_args = await policy_actor.get_vllm_args.choose()
sampling_params = get_default_sampling_params(vllm_args, overrides=overrides)

router = await router_mesh.spawn(
"policy_router",
PolicyRouter,
policy=policy_actor,
sampling_params=sampling_params,
)
print("Spawning service")
policy = await spawn_service(service_config, Policy, config=config)
session_id = await policy.start_session()

await policy_actor.setup.call()
await router.setup.call()
print("Model setup")
print("Kick off background processing")
asyncio.create_task(policy.run_processing.call())

router.run.call()
print("Model running")
print("Request Generation")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)

prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
responses: List[CompletionOutput] = await router.generate.call_one(prompt)
print("Terminating session")
await policy.shutdown.call()
await policy.terminate_session(session_id)

for batch, response in enumerate(responses):
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print(f"Batch {batch}:")
print(f"User: {prompt}\nAssistant: {response.text}")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

await router.shutdown.call()


if __name__ == "__main__":
Contributor:

Can we move the policy main into its own app? apps/vllm.py or something

I think having a standalone vLLM application will be useful for Stanford too

Contributor Author:

Totally, let's do it in a separate PR to unblock this one.
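
For reference, a rough sketch of the kind of standalone app the reviewer suggests. The file name apps/vllm.py, the module path forge.actors.policy, and the CompletionOutput import location are assumptions; the service calls simply mirror the _test function in this diff:

```python
# apps/vllm.py (hypothetical follow-up, not part of this PR)
import asyncio
from typing import List

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
from vllm.outputs import CompletionOutput


async def main() -> None:
    config = PolicyConfig(
        num_workers=2,
        worker_params=WorkerConfig(
            model="meta-llama/Llama-3.1-8B-Instruct",
            tensor_parallel_size=2,
        ),
        sampling_params=SamplingOverrides(num_samples=1),
    )
    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

    # Spawn the Policy service and open a session, as in _test below.
    policy = await spawn_service(service_config, Policy, config=config)
    session_id = await policy.start_session()

    # Kick off background scheduling/processing.
    asyncio.create_task(policy.run_processing.call())

    responses: List[CompletionOutput] = await policy.generate.choose(
        prompt="Tell me a joke"
    )
    for response in responses:
        print(response.text)

    await policy.shutdown.call()
    await policy.terminate_session(session_id)


if __name__ == "__main__":
    asyncio.run(main())
```

Keeping the entry point in an app module like this would also leave policy.py importable without side effects.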

config = {
"model": "meta-llama/Llama-3.1-8B-Instruct",
"tensor_parallel_size": 2,
"pipeline_parallel_size": 1,
"enforce_eager": True,
"resources": 2,
}
# asyncio.run(_test(config))
# asyncio.run(_test(config, guided_decoding=True))
# asyncio.run(_test(config, num_samples=2))
# asyncio.run(_test(config, guided_decoding=True, num_samples=3))
config = PolicyConfig(
num_workers=2,
worker_params=WorkerConfig(
model="meta-llama/Llama-3.1-8B-Instruct",
tensor_parallel_size=2,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
),
sampling_params=SamplingOverrides(
num_samples=2,
guided_decoding=True,
),
)
asyncio.run(_test(config))