131 changes: 67 additions & 64 deletions src/forge/actors/policy.py
@@ -22,7 +22,6 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -53,7 +52,6 @@
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt
from forge.env import TORCHSTORE_USE_RDMA
from forge.interfaces import Policy as PolicyInterface
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
@@ -63,7 +61,7 @@


@dataclass
class Policy(PolicyInterface):
class Policy(ForgeActor):
"""Instance of a vLLM-based Policy.

This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The
@@ -72,8 +70,8 @@ class Policy(PolicyInterface):
Args:
engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
available_devices (str): The available devices to use for the vLLM engine.
use_dcp (bool): Whether to use DCP for NFS-based weight sync.
use_dcp (bool): Whether to use DCP for NFS-based weight sync. Defaults to True when RDMA is disabled in torchstore.

Example:
>>> policy = await Policy.options(procs=1, num_replicas=1, with_gpus=True).as_service(
@@ -88,19 +86,13 @@ class Policy(PolicyInterface):

engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
available_devices: str | None = None
use_dcp: bool = (
TORCHSTORE_USE_RDMA.get_value() == 0
) # torchstore currently only accepts 0 or 1
# Remaining variables are initialized in self.setup()
lora_request: LoRARequest | None = None
tokenization_kwargs: dict = field(default_factory=dict)
policy_worker: PolicyWorker | None = None
use_dcp: bool | None = None

def __post_init__(self):
super().__init__()
self._run_task: asyncio.Task | None = None
self._policy_proc: ProcMesh | None = None
self.worker: PolicyWorker | None = None
self._worker_procs: ProcMesh | None = None
self.running = False
self.policy_version: int = 0
@@ -113,16 +105,18 @@ def __post_init__(self):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

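# Enable DCP exactly when torchstore RDMA is off; TORCHSTORE_USE_RDMA
# currently accepts only 0 or 1. The default is resolved here so the env
# var is read at actor init time rather than at class-definition time.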
if self.use_dcp is None:
self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0

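# Called by ``launch`` once the worker mesh is spawned; replaces the old
# ``policy_worker`` constructor field.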
@endpoint
async def register_worker(self, worker: PolicyWorker) -> None:
self.worker = worker
logger.debug("Registered PolicyWorker on Policy.")

@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Policy"],
*,
engine_args: EngineArgs | Mapping = EngineArgs(),
sampling_params: SamplingParams | Mapping = SamplingParams(),
available_devices: str | None = None,
use_dcp: bool = (
TORCHSTORE_USE_RDMA.get_value() == 0
), # torchstore currently only accepts 0 or 1
*args,
**kwargs,
) -> "Policy":
"""Launch the policy with its workers.
@@ -154,45 +148,47 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
policy_proc_config.with_gpus = False
policy_proc = await get_proc_mesh(process_config=policy_proc_config)

if isinstance(engine_args, Mapping):
engine_args = EngineArgs(**engine_args)
engine_args._is_v1_supported_oracle = lambda *_: True # Always default on
logger.debug(f"Resolved engine args: {engine_args}")
# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)

vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
workers = worker_procs.spawn(
"vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
)
# if isinstance(engine_args, Mapping):
# engine_args = EngineArgs(**engine_args)
# engine_args._is_v1_supported_oracle = lambda *_: True # Always default on
# logger.debug(f"Resolved engine args: {engine_args}")

if isinstance(sampling_params, Mapping):
sampling_params = SamplingParams.from_optional(**sampling_params)
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
logger.debug(f"Resolved sampling params: {sampling_params}")
# vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
# engine_args = kwargs["engine_args"]
# if isinstance(engine_args, Mapping):
# engine_args = EngineArgs(**engine_args)
# engine_args._is_v1_supported_oracle = lambda *_: True # Always default on
# vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
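# Raw kwargs are forwarded to the worker mesh untouched; PolicyWorker's
# __post_init__ coerces Mapping-style engine args into EngineArgs itself.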
worker = worker_procs.spawn("vllm_worker", PolicyWorker, *args, **kwargs)

# if isinstance(sampling_params, Mapping):
# sampling_params = SamplingParams.from_optional(**sampling_params)
# sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
# logger.debug(f"Resolved sampling params: {sampling_params}")

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
policy = policy_proc.spawn(
actor_name,
cls,
engine_args=engine_args,
sampling_params=sampling_params,
available_devices=available_devices,
policy_worker=workers,
*args,
**kwargs,
)
policy._policy_proc = policy_proc
policy._worker_procs = worker_procs
await policy.register_worker.call(worker)
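# setup() raises if no worker is registered, so registration must come first.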
await policy.setup.call()
return policy

@endpoint
async def setup(self):
"""Mirrors the __init__ of vLLM's LLMEngine."""
if self.policy_worker is None:
if self.worker is None:
raise RuntimeError(
"Policy worker is None; it is normally attached via ``register_worker`` during ``launch``."
)
await self.policy_worker.setup.call()
await self.worker.setup.call()

self.request_id = 0
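# request id -> (ParentRequest when sampling n > 1, future that generate()
# awaits until the final output for that request arrives)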
self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
@@ -203,11 +199,6 @@ async def setup(self):
self.request_lock = asyncio.Condition() # Guard for accepting_requests
self.update_lock = asyncio.Condition() # Guard for updating requests

vllm_config: VllmConfig = self.engine_args.create_engine_config(
UsageContext.LLM_CLASS
)
self.max_model_len = vllm_config.model_config.max_model_len

# Setup processors
# TODO: move all processing to the Environment
# TODO: add support for `log_stats` and `mm_registry`
@@ -222,7 +213,7 @@
self.output_processor = OutputProcessor(tokenizer, log_stats=None)

# Configure KV caches
kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
kv_cache_configs = await self.worker.setup_kv_cache.call()
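# .call() returns one config per worker rank; take the rank-0 entry.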
_, kv_cache_config = next(kv_cache_configs.items())
vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
vllm_config.cache_config.num_cpu_blocks = 0
@@ -261,7 +252,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
self.request_id = (self.request_id + 1) % sys.maxsize  # `+= 1 % sys.maxsize` never wrapped
request_id = str(self.request_id)

tokenization_kwargs = self.tokenization_kwargs or {}
tokenization_kwargs = {}
# TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
_validate_truncation_size(
@@ -274,7 +265,6 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
prompt={"prompt": prompt},
params=self.sampling_params,
arrival_time=None,
lora_request=self.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=None,
priority=priority,
@@ -341,8 +331,9 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
def _preprocess_add_request(
self, request: EngineCoreRequest
) -> tuple[Request, int]:
""" (forge/issues/332) Will require attention when we bump vllm versions
https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419"""
"""(forge/issues/332) Will require attention when we bump vllm versions
https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
"""
if request.mm_hashes is not None:
raise NotImplementedError("Support for mm_hash is not implemented yet.")
req = Request.from_engine_core_request(request)
@@ -358,9 +349,7 @@ async def run(self) -> None:
self.running = True
while self.running:
scheduler_output = self.scheduler.schedule()
worker_outputs = await self.policy_worker.execute_model.call(
scheduler_output
)
worker_outputs = await self.worker.execute_model.call(scheduler_output)

# The results of `execute_model` are gathered on the driver rank (rank 0)
_, worker_output = next(worker_outputs.items())
@@ -427,8 +416,8 @@ async def update_weights(self, policy_version: int) -> None:
record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
# Call update_weights on every policy_worker
await self.policy_worker.update_weights.call(policy_version)
# Call update_weights on every policy worker
await self.worker.update_weights.call(policy_version)
self.policy_version = policy_version

# After updating the weights, we need to reset the KV cache
@@ -507,13 +496,16 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
async def _test_save_model_params(self):
"""Save model parameters before weight update, used for tesing purposes only."""
logger.info("[Policy] save model parameters for testing.")
await self.policy_worker._test_save_model_params.call()
await self.worker._test_save_model_params.call()

@endpoint
async def _test_validate_model_params(self, validate_fn):
"""Validate updated model params using validate_fn."""
logger.info("[Policy] start validating model parameters.")
return await self.policy_worker._test_validate_model_params.call(validate_fn)
return await self.worker._test_validate_model_params.call(validate_fn)


from typing import Any


@dataclass
@@ -525,20 +517,31 @@ class PolicyWorker(ForgeActor):
the creation and invocation of all PolicyWorkers.
"""

vllm_config: VllmConfig
state_dict_key: str = "model_state_dict"
# TODO: remove this later since no plumbing exists to change this value.
# Also, whether to use dcp or not can be inferred from torchstore get() call.
use_dcp: bool = True

# used for testing purposes only
engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
use_dcp: bool | None = None
# TODO: Remove below param
_test_prev_params = {}

def __post_init__(self):
super().__init__()
if isinstance(self.engine_args, Mapping):
self.engine_args = EngineArgs(**self.engine_args)
self.engine_args._is_v1_supported_oracle = lambda *_: True
# Note: vllm_config creation is deferred to setup() method to avoid
# model inspection issues during remote actor initialization
self.vllm_config = None
print("HELLLOOOO")

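# Same default as Policy: use DCP exactly when torchstore RDMA is off.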
if self.use_dcp is None:
self.use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0

@endpoint
async def setup(self):
# Create vllm_config here instead of during initialization to avoid
# model inspection issues during remote actor initialization
self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)

self.rank = current_rank().rank
os.environ["RANK"] = str(self.rank)
parallel_config = self.vllm_config.parallel_config