Leverage Services for Policy + Rename PolicyRouter #70
Changes from 2 commits
@@ -13,10 +13,16 @@
from typing import Dict, List

import torch
from forge.controller import spawn_actors
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from omegaconf import DictConfig, OmegaConf
from torchstore import MultiProcessStore

from torchstore._state_dict_utils import DELIM

from vllm.engine.arg_utils import EngineArgs
@@ -39,24 +45,43 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.data.sharding import VLLMSharding

logger = logging.getLogger(__name__)


@dataclass
class PolicyRouter(Actor):
class Policy(PolicyInterface):
    # TODO: Add dp support
    policy: Actor
    config: DictConfig
    # Gets set up by setup worker
    policy_worker: Actor = None

    sampling_params: SamplingParams = None
    lora_request: LoRARequest = None
    tokenization_kwargs: dict = None

    @endpoint
    async def setup(self):
        # Set up policy_worker
        await self.spawn_workers()

        self.request_id = 0
        self.requests: Dict[str, Tuple[None | ParentRequest, asyncio.Future]] = {}
        self.vllm_args = await self.policy.get_vllm_args.choose()
        self.vllm_args = await self.policy_worker.get_vllm_args.choose()

        # Setup sampling params
        sampling_overrides = self.config.sampling_params
        overrides = {
            "n": sampling_overrides.num_samples,
            "guided_decoding": (
                GuidedDecodingParams(choice=["Positive", "Negative"])
                if sampling_overrides.guided_decoding
                else None
            ),
        }
        self.sampling_params = get_default_sampling_params(
            self.vllm_args, overrides=overrides
        )

        # Setup processors
        # TODO: move all processing to the Environment
        # TODO: add support for `log_stats` and `mm_registry`
@@ -70,9 +95,9 @@ async def setup(self):
        )
        self.output_processor = OutputProcessor(tokenizer, log_stats=None)

        # Setup schduuler
        # Setup scheduler
        # TODO: Add support for `log_stats`
        kv_cache_configs = await self.policy.setup_kv_cache.call()
        kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
        kv_cache_config = kv_cache_configs._values[0]
        self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
        self.vllm_args.cache_config.num_cpu_blocks = 0
@@ -86,15 +111,26 @@ async def setup(self):
            log_stats=None,
        )

    async def spawn_workers(self):
        self.worker_mesh = await proc_mesh(
            gpus=self.config.num_workers,
            env={
                "MASTER_ADDR": str(get_loopback_ip()),
                "MASTER_PORT": str(get_open_port()),
            },
        )
        self.policy_worker = await self.worker_mesh.spawn(
            "policy_worker", PolicyWorker, **self.config.worker_params
        )
        await self.policy_worker.setup.call()

    @endpoint
    async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
        self.request_id += 1 % sys.maxsize
        request_id = str(self.request_id)  # implement from a counter

        # Wraps prompt into a dict
        prompt: Dict[str, str] = convert_input(prompt)
        if self.sampling_params is None:
            self.sampling_params = get_default_sampling_params(self.vllm_args)

        # truncate prmpt
        tokenization_kwargs = self.tokenization_kwargs or {}
@@ -161,15 +197,17 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
        return request, 0  # Unused Arg: Current Wave

    @endpoint
    async def run(self):
    async def run_processing(self):
        # TODO: add support for `iteration_stats`
        # TODO: move postprocessing out of loop to not block
        parallel_config = self.vllm_args.parallel_config
        output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
        self.running = True
        while self.running:
            scheduler_output = self.scheduler.schedule()
            worker_outputs = await self.policy.execute_model.call(scheduler_output)
            worker_outputs = await self.policy_worker.execute_model.call(
                scheduler_output
            )
            worker_output = worker_outputs._values[output_rank]
            outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
            outputs = outputs.get(0) or EngineCoreOutputs()
@@ -185,19 +223,23 @@ async def run(self):
                _, fut = self.requests.pop(request_output.request_id)
                fut.set_result(request_output.outputs)

    @endpoint
    async def update_weights(self):
        """Update the policy weights."""
        pass

    @endpoint
    async def shutdown(self):
        self.running = False


@dataclass
class Policy(Actor):
class PolicyWorker(Actor):
    model: str
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    enforce_eager: bool = False
    vllm_args: EngineArgs = None
    resources: int = 1
    state_dict_key: str = "model_state_dict"

    def __post_init__(self):
@@ -240,7 +282,6 @@ def __post_init__(self):
            setattr(self.vllm_args, key, value)
        # Build Config
        self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
        assert self.vllm_args.parallel_config.world_size == self.resources

    @endpoint
    async def setup(self, store: MultiProcessStore = None):
@@ -387,67 +428,49 @@ def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
    return params


async def _test(config, guided_decoding=False, num_samples=1):
    # TODO: Create proper test
    router_mesh = await proc_mesh(gpus=1)
    policy_mesh = await proc_mesh(
        gpus=config["resources"],
        env={
            "MASTER_ADDR": str(get_loopback_ip()),
            "MASTER_PORT": str(get_open_port()),
        },
# TODO: Create proper test
async def _test(config: DictConfig):
    prompt = (
        "What is 3+5?" if config.sampling_params.guided_decoding else "Tell me a joke"
    )
    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)

    policy_actor = await policy_mesh.spawn("policy", Policy, **config)

    # TODO: Make this customizable from the config
    overrides = {
        "n": num_samples,
        "guided_decoding": (
            GuidedDecodingParams(choice=["Positive", "Negative"])
            if guided_decoding
            else None
        ),
    }

    vllm_args = await policy_actor.get_vllm_args.choose()
    sampling_params = get_default_sampling_params(vllm_args, overrides=overrides)

    router = await router_mesh.spawn(
        "policy_router",
        PolicyRouter,
        policy=policy_actor,
        sampling_params=sampling_params,
    )
    print("Spawning service")
    policy = await spawn_service(service_config, Policy, config=config)
    session_id = await policy.start_session()

    await policy_actor.setup.call()
    await router.setup.call()
    print("Model setup")
    print("Kick off background processing")
    asyncio.create_task(policy.run_processing.call())

    router.run.call()
    print("Model running")
    print("Request Generation")
    responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)

    prompt = "What is 3+5?" if guided_decoding else "Tell me a joke"
    responses: List[CompletionOutput] = await router.generate.call_one(prompt)
    print("Terminating session")
    await policy.shutdown.call()
    await policy.terminate_session(session_id)

    for batch, response in enumerate(responses):
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print(f"Batch {batch}:")
        print(f"User: {prompt}\nAssistant: {response.text}")
        print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

    await router.shutdown.call()


if __name__ == "__main__":
Review comment: Can we move the policy main into its own app? apps/vllm.py or something. I think having a standalone vLLM application will be useful for Stanford too.

Reply: Totally, let's do it in a separate PR to unblock this one.

(A sketch of such a standalone entry point follows the end of the diff below.)
    config = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "tensor_parallel_size": 2,
        "pipeline_parallel_size": 1,
        "enforce_eager": True,
        "resources": 2,
    }
    # asyncio.run(_test(config))
    # asyncio.run(_test(config, guided_decoding=True))
    # asyncio.run(_test(config, num_samples=2))
    # asyncio.run(_test(config, guided_decoding=True, num_samples=3))
    config = OmegaConf.create(
        {
            "num_workers": 2,
            "worker_params": {
                "model": "meta-llama/Llama-3.1-8B-Instruct",
                "tensor_parallel_size": 2,
                "pipeline_parallel_size": 1,
                "enforce_eager": True,
                "vllm_args": None,
            },
            "sampling_params": {
                "guided_decoding": True,
                "num_samples": 2,
            },
        }
    )
    asyncio.run(_test(config))
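Following up on the review thread above about moving the policy `__main__` block into its own app: the sketch below is only an illustration of what a standalone apps/vllm.py entry point could look like. It reuses ServiceConfig, spawn_service, and the Policy service endpoints exactly as they appear in this diff; the module path forge.actors.policy, the YAML config path, and the default prompt are assumptions, not part of this PR.

# apps/vllm.py -- hypothetical standalone policy app (illustration only, not part of this PR)
import asyncio

from forge.actors.policy import Policy  # assumed module path for the Policy service above
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
from omegaconf import DictConfig, OmegaConf


async def main(config: DictConfig) -> None:
    # Spawn the Policy service the same way _test() does in this diff.
    service_config = ServiceConfig(procs_per_replica=1, num_replicas=1)
    policy = await spawn_service(service_config, Policy, config=config)
    session_id = await policy.start_session()

    # Kick off the background scheduling/processing loop.
    asyncio.create_task(policy.run_processing.call())

    responses = await policy.generate.choose(prompt="Tell me a joke")
    for response in responses:
        print(response.text)

    await policy.shutdown.call()
    await policy.terminate_session(session_id)


if __name__ == "__main__":
    # Config loading is a placeholder; the path and schema are assumptions.
    asyncio.run(main(OmegaConf.load("apps/vllm.yaml")))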
Review comment: I don't like individual components having knowledge of our frontend config system. Can we parameterize this fully, or pass in a policy-specific config class?
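One possible shape for that, as a minimal sketch under the assumption that only the Policy constructor changes: replace the DictConfig field with a small policy-specific config dataclass, so the actor no longer imports OmegaConf or knows about the frontend config layout. The names PolicyConfig, WorkerConfig, and SamplingOverrides below are hypothetical.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class WorkerConfig:
    # Hypothetical: mirrors the `worker_params` block used in this PR's config.
    model: str = "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    enforce_eager: bool = False
    vllm_args: Any = None


@dataclass
class SamplingOverrides:
    # Hypothetical: mirrors the `sampling_params` block used in this PR's config.
    num_samples: int = 1
    guided_decoding: bool = False


@dataclass
class PolicyConfig:
    # A policy-specific config the frontend constructs and passes in,
    # instead of handing the actor a raw DictConfig.
    num_workers: int = 1
    worker_params: WorkerConfig = field(default_factory=WorkerConfig)
    sampling_params: SamplingOverrides = field(default_factory=SamplingOverrides)

The frontend could still hydrate this from YAML at the boundary, e.g. by merging OmegaConf.structured(PolicyConfig) with the user config and converting the result back to the dataclass with OmegaConf.to_object, so only the caller ever touches DictConfig.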