153 changes: 88 additions & 65 deletions src/forge/actors/policy.py
@@ -13,10 +13,16 @@
from typing import Dict, List

import torch
from forge.controller import spawn_actors
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from omegaconf import DictConfig, OmegaConf
from torchstore import MultiProcessStore

from torchstore._state_dict_utils import DELIM

from vllm.engine.arg_utils import EngineArgs
@@ -39,24 +45,43 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.data.sharding import VLLMSharding

logger = logging.getLogger(__name__)


@dataclass
class PolicyRouter(Actor):
class Policy(PolicyInterface):
# TODO: Add dp support
policy: Actor
config: DictConfig
Member: I don't like individual components having knowledge of our frontend config system. Can we parameterize this fully or pass in a policy-specific config class?
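
A rough sketch of what such a policy-specific config class could look like, with field names mirrored from the worker_params and sampling_params blocks at the bottom of this file; the class and field names are illustrative assumptions, not part of this PR:

```python
from dataclasses import dataclass, field


@dataclass
class SamplingOverrides:
    """Hypothetical typed version of config.sampling_params."""

    num_samples: int = 1
    guided_decoding: bool = False


@dataclass
class WorkerParams:
    """Hypothetical typed version of config.worker_params."""

    model: str = "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    enforce_eager: bool = False


@dataclass
class PolicyConfig:
    """Typed config the Policy actor could accept instead of a DictConfig."""

    num_workers: int = 1
    worker_params: WorkerParams = field(default_factory=WorkerParams)
    sampling_params: SamplingOverrides = field(default_factory=SamplingOverrides)
```

This would keep the actor decoupled from the frontend config system: the caller translates its DictConfig into a PolicyConfig at spawn time.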

# Gets set up by setup worker
policy_worker: Actor = None

sampling_params: SamplingParams = None
lora_request: LoRARequest = None
tokenization_kwargs: dict = None

@endpoint
async def setup(self):
# Set up policy_worker
await self.spawn_workers()

self.request_id = 0
self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_args = await self.policy.get_vllm_args.choose()
self.vllm_args = await self.policy_worker.get_vllm_args.choose()

# Setup sampling params
sampling_overrides = self.config.sampling_params
overrides = {
"n": sampling_overrides.num_samples,
"guided_decoding": (
GuidedDecodingParams(choice=["Positive", "Negative"])
if sampling_overrides.guided_decoding
else None
),
}
self.sampling_params = get_default_sampling_params(
self.vllm_args, overrides=overrides
)

# Setup processors
# TODO: move all processing to the Environment
# TODO: add support for `log_stats` and `mm_registry`
@@ -70,9 +95,9 @@ async def setup(self):
)
self.output_processor = OutputProcessor(tokenizer, log_stats=None)

# Setup schduuler
# Setup scheduler
# TODO: Add support for `log_stats`
kv_cache_configs = await self.policy.setup_kv_cache.call()
kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
kv_cache_config = kv_cache_configs._values[0]
self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_args.cache_config.num_cpu_blocks = 0
@@ -86,15 +111,26 @@ async def setup(self):
log_stats=None,
)

async def spawn_workers(self):
self.worker_mesh = await proc_mesh(
gpus=self.config.num_workers,
env={
"MASTER_ADDR": str(get_loopback_ip()),
"MASTER_PORT": str(get_open_port()),
},
)
self.policy_worker = await self.worker_mesh.spawn(
"policy_worker", PolicyWorker, **self.config.worker_params
)
await self.policy_worker.setup.call()

@endpoint
async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
self.request_id += 1 % sys.maxsize
request_id = str(self.request_id) # implement from a counter

# Wraps prompt into a dict
prompt: Dict[str, str] = convert_input(prompt)
if self.sampling_params is None:
self.sampling_params = get_default_sampling_params(self.vllm_args)

# truncate prompt
tokenization_kwargs = self.tokenization_kwargs or {}
@@ -161,15 +197,17 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
return request, 0 # Unused Arg: Current Wave

@endpoint
async def run(self):
async def run_processing(self):
# TODO: add support for `iteration_stats`
# TODO: move postprocessing out of loop to not block
parallel_config = self.vllm_args.parallel_config
output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
self.running = True
while self.running:
scheduler_output = self.scheduler.schedule()
worker_outputs = await self.policy.execute_model.call(scheduler_output)
worker_outputs = await self.policy_worker.execute_model.call(
scheduler_output
)
worker_output = worker_outputs._values[output_rank]
outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
outputs = outputs.get(0) or EngineCoreOutputs()
@@ -185,19 +223,23 @@ async def run(self):
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output.outputs)

@endpoint
async def update_weights(self):
"""Update the policy weights."""
pass

@endpoint
async def shutdown(self):
self.running = False


@dataclass
class Policy(Actor):
class PolicyWorker(Actor):
model: str
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None
resources: int = 1
state_dict_key: str = "model_state_dict"

def __post_init__(self):
@@ -240,7 +282,6 @@ def __post_init__(self):
setattr(self.vllm_args, key, value)
# Build Config
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
assert self.vllm_args.parallel_config.world_size == self.resources

@endpoint
async def setup(self, store: MultiProcessStore = None):
@@ -387,67 +428,49 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
return params


async def _test(config, guided_decoding=False, num_samples=1):
# TODO: Create proper test
router_mesh = await proc_mesh(gpus=1)
policy_mesh = await proc_mesh(
gpus=config["resources"],
env={
"MASTER_ADDR": str(get_loopback_ip()),
"MASTER_PORT": str(get_open_port()),
},
# TODO: Create proper test
async def _test(config: DictConfig):
prompt = (
"What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
)
service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

policy_actor = await policy_mesh.spawn("policy", Policy, **config)

# TODO: Make this customizable from the config
overrides = {
"n": num_samples,
"guided_decoding": (
GuidedDecodingParams(choice=["Positive", "Negative"])
if guided_decoding
else None
),
}

vllm_args = await policy_actor.get_vllm_args.choose()
sampling_params = get_default_sampling_params(vllm_args, overrides=overrides)

router = await router_mesh.spawn(
"policy_router",
PolicyRouter,
policy=policy_actor,
sampling_params=sampling_params,
)
print("Spawning service")
policy = await spawn_service(service_config, Policy, config=config)
session_id = await policy.start_session()

await policy_actor.setup.call()
await router.setup.call()
print("Model setup")
print("Kick off background processing")
asyncio.create_task(policy.run_processing.call())

router.run.call()
print("Model running")
print("Request Generation")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)

prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
responses: List[CompletionOutput] = await router.generate.call_one(prompt)
print("Terminating session")
await policy.shutdown.call()
await policy.terminate_session(session_id)

for batch, response in enumerate(responses):
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print(f"Batch {batch}:")
print(f"User: {prompt}\nAssistant: {response.text}")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

await router.shutdown.call()


if __name__ == "__main__":
Contributor: Can we move the policy main into its own app? apps/vllm.py or something

I think having a standalone vLLM application will be useful for Stanford too

Contributor Author: Totally, let's do it in a separate PR to unblock this one.
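
For reference, a rough sketch of what a standalone apps/vllm.py entry point might look like, reusing the _test-style driver below; the module path forge.actors.policy and the hard-coded config are assumptions, not part of this PR:

```python
# apps/vllm.py (hypothetical location suggested above)
import asyncio

from forge.actors.policy import Policy  # assumed import path for the actor below
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
from omegaconf import OmegaConf


async def main() -> None:
    # Same config shape as the one built in policy.py's __main__ below.
    config = OmegaConf.create(
        {
            "num_workers": 2,
            "worker_params": {
                "model": "meta-llama/Llama-3.1-8B-Instruct",
                "tensor_parallel_size": 2,
                "pipeline_parallel_size": 1,
                "enforce_eager": True,
                "vllm_args": None,
            },
            "sampling_params": {"guided_decoding": False, "num_samples": 1},
        }
    )

    # Spawn the Policy service and open a session, mirroring _test below.
    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
    policy = await spawn_service(service_config, Policy, config=config)
    session_id = await policy.start_session()

    # Kick off background processing, then request a generation.
    asyncio.create_task(policy.run_processing.call())
    responses = await policy.generate.choose(prompt="Tell me a joke")
    for response in responses:
        print(response.text)

    await policy.shutdown.call()
    await policy.terminate_session(session_id)


if __name__ == "__main__":
    asyncio.run(main())
```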

config = {
"model": "meta-llama/Llama-3.1-8B-Instruct",
"tensor_parallel_size": 2,
"pipeline_parallel_size": 1,
"enforce_eager": True,
"resources": 2,
}
# asyncio.run(_test(config))
# asyncio.run(_test(config, guided_decoding=True))
# asyncio.run(_test(config, num_samples=2))
# asyncio.run(_test(config, guided_decoding=True, num_samples=3))
config = OmegaConf.create(
{
"num_workers": 2,
"worker_params": {
"model": "meta-llama/Llama-3.1-8B-Instruct",
"tensor_parallel_size": 2,
"pipeline_parallel_size": 1,
"enforce_eager": True,
"vllm_args": None,
},
"sampling_params": {
"guided_decoding": True,
"num_samples": 2,
},
}
)
asyncio.run(_test(config))