Merged
Changes from 5 commits
151 changes: 60 additions & 91 deletions src/forge/actors/generator.py
@@ -22,7 +22,6 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -53,7 +52,6 @@
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt
from forge.env import TORCHSTORE_USE_RDMA
from forge.interfaces import Policy as GeneratorInterface
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
@@ -63,20 +61,19 @@


@dataclass
class Generator(GeneratorInterface):
"""Instance of a vLLM-based Generator.
class Generator(ForgeActor):
"""Instance of a vLLM-based generator.

This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The
main difference is that all communications are controlled here via Monarch's proc meshes.

Args:
engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
available_devices (str): The available devices to use for the vLLM engine.
use_dcp (bool): Whether to use DCP for NFS-based weight sync.
use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on
whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled.

Example:

>>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
... engine_args=EngineArgs(...),
... sampling_params=SamplingParams(...),
@@ -89,50 +86,50 @@ class Generator(GeneratorInterface):

engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
available_devices: str | None = None
use_dcp: bool = (
TORCHSTORE_USE_RDMA.get_value() == 0
) # torchstore currently only accepts 0 or 1
# Remaining variables are initialized in self.setup()
lora_request: LoRARequest | None = None
tokenization_kwargs: dict = field(default_factory=dict)
generator_worker: GeneratorWorker | None = None
use_dcp_for_weight_sync: bool | None = None

def __post_init__(self):
super().__init__()
self._run_task: asyncio.Task | None = None
self._generator_proc: ProcMesh | None = None
self._policy_proc: ProcMesh | None = None
self._worker_procs: ProcMesh | None = None
self.worker: GeneratorWorker | None = None
self.running = False
self.generator_version: int = 0

if isinstance(self.engine_args, Mapping):
self.engine_args = EngineArgs(**self.engine_args)
self.engine_args._is_v1_supported_oracle = lambda *_: True
self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)

if isinstance(self.sampling_params, Mapping):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

if self.use_dcp_for_weight_sync is None:
self.use_dcp_for_weight_sync = TORCHSTORE_USE_RDMA.get_value() == 0
logger.debug(f"{self.use_dcp_for_weight_sync=}")

@endpoint
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config

@endpoint
async def register_worker(self, worker: GeneratorWorker) -> None:
self.worker = worker
logger.debug("Registered GeneratorWorker on Generator.")

@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Generator"],
*,
engine_args: EngineArgs | Mapping = EngineArgs(),
sampling_params: SamplingParams | Mapping = SamplingParams(),
available_devices: str | None = None,
use_dcp: bool = (
TORCHSTORE_USE_RDMA.get_value() == 0
), # torchstore currently only accepts 0 or 1
*args,
**kwargs,
) -> "Generator":
"""Launch the Generator with its workers.
"""Custom launch for the Generator service with its workers.

We override the default Service launch method in order to set up Actors (GeneratorWorker) within this "coordinating" Actor.
We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers
and the generator in setup.

The args here generally should match those in the `__init__` method of the Generator class.
"""
# Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
process_config: ProcessConfig = ProcessConfig(
@@ -141,60 +138,46 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
with_gpus=cls.with_gpus,
mesh_name=cls.mesh_name,
)
worker_procs = await get_proc_mesh(process_config=process_config)

# TODO - issues/144 we will want to ensure colocation with workers
# We're currently locating the Generator on the local host proc mesh
# vLLM initialization without setting env variables at proc_mesh creation
# level leads to issues.
# Once we can create multiple proc meshes on a host mesh, we can ensure
# host colocation
# level leads to issues. Once we can create multiple proc meshes on a host mesh,
# we can ensure host colocation
generator_proc_config = copy(process_config)
generator_proc_config.procs = 1
generator_proc_config.hosts = None
generator_proc_config.with_gpus = False
generator_proc = await get_proc_mesh(process_config=generator_proc_config)

if isinstance(engine_args, Mapping):
engine_args = EngineArgs(**engine_args)
engine_args._is_v1_supported_oracle = lambda *_: True # Always default on
logger.debug(f"Resolved engine args: {engine_args}")

vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
workers = worker_procs.spawn(
"vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp
)

if isinstance(sampling_params, Mapping):
sampling_params = SamplingParams.from_optional(**sampling_params)
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
logger.debug(f"Resolved sampling params: {sampling_params}")

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
generator = generator_proc.spawn(
actor_name,
cls,
engine_args=engine_args,
sampling_params=sampling_params,
available_devices=available_devices,
generator_worker=workers,
*args,
**kwargs,
)

worker_procs = await get_proc_mesh(process_config=process_config)
vllm_config = (
await generator.get_vllm_config.call_one()
) # Config should be the same across all actors
worker = worker_procs.spawn(
"vllm_worker", GeneratorWorker, vllm_config=vllm_config
)
await worker.setup.call()
await generator.register_worker.call(worker)

generator._generator_proc = generator_proc
generator._worker_procs = worker_procs
await generator.setup.call()

return generator
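For context, a minimal end-to-end usage sketch of the launched service follows. The model id and sampling fields are placeholders, and the generate.route(...) and shutdown() call shapes are assumptions based on the service pattern in the class docstring rather than something shown in this diff.

# Sketch only (not part of this diff): launching the Generator as a service and issuing a request.
import asyncio

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from forge.actors.generator import Generator


async def main() -> None:
    # as_service() goes through the custom launch() above: it spawns the
    # coordinating Generator actor and then its GeneratorWorker mesh.
    generator = await Generator.options(
        procs=1, num_replicas=1, with_gpus=True
    ).as_service(
        engine_args=EngineArgs(model="<model-name>"),   # placeholder model id
        sampling_params=SamplingParams(max_tokens=64),  # placeholder params
    )
    try:
        completions = await generator.generate.route(prompt="What is 3 + 5?")  # assumed call shape
        for completion in completions:
            print(completion)
    finally:
        await generator.shutdown()  # assumed service teardown


if __name__ == "__main__":
    asyncio.run(main())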

@endpoint
async def setup(self):
"""Mirrors the __init__ of vLLM's LLMEngine."""
if self.generator_worker is None:
raise RuntimeError(
"Geneator worker should not be None. Usually it would be attached to Generator in the ``launch`` method."
)
await self.generator_worker.setup.call()

self.request_id = 0
self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}

@@ -204,35 +187,30 @@ async def setup(self):
self.request_lock = asyncio.Condition() # Guard for accepting_requests
self.update_lock = asyncio.Condition() # Guard for updating requests

vllm_config: VllmConfig = self.engine_args.create_engine_config(
UsageContext.LLM_CLASS
)
self.max_model_len = vllm_config.model_config.max_model_len

# Setup processors
# TODO: move all processing to the Environment
# TODO: add support for `log_stats` and `mm_registry`
tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config,
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
)
self.processor = Processor(
vllm_config=vllm_config, tokenizer=tokenizer, mm_registry=None
vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
)
self.output_processor = OutputProcessor(tokenizer, log_stats=None)

# Configure KV caches
kv_cache_configs = await self.generator_worker.setup_kv_cache.call()
kv_cache_configs = await self.worker.setup_kv_cache.call()
_, kv_cache_config = next(kv_cache_configs.items())
vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
vllm_config.cache_config.num_cpu_blocks = 0
self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_config.cache_config.num_cpu_blocks = 0

# Setup scheduler
# TODO: Add support for `log_stats`
structured_output_manager = StructuredOutputManager(vllm_config)
structured_output_manager = StructuredOutputManager(self.vllm_config)
self.scheduler = Scheduler(
vllm_config=vllm_config,
vllm_config=self.vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=structured_output_manager,
include_finished_set=False,
@@ -262,11 +240,11 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
self.request_id = (self.request_id + 1) % sys.maxsize
request_id = str(self.request_id)

tokenization_kwargs = self.tokenization_kwargs or {}
tokenization_kwargs = {}
# TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
_validate_truncation_size(
self.max_model_len,
self.vllm_config.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
@@ -275,7 +253,6 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
prompt={"prompt": prompt},
params=self.sampling_params,
arrival_time=None,
lora_request=self.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=None,
priority=priority,
@@ -360,9 +337,7 @@ async def run(self) -> None:
self.running = True
while self.running:
scheduler_output = self.scheduler.schedule()
worker_outputs = await self.generator_worker.execute_model.call(
scheduler_output
)
worker_outputs = await self.worker.execute_model.call(scheduler_output)

# The results of `execute_model` are gathered on the driver rank (rank 0)
_, worker_output = next(worker_outputs.items())
@@ -431,8 +406,8 @@ async def update_weights(self, version: int) -> None:
)

logger.debug(f"Starting weight update on {self.__class__.__name__}")
# Call update_weights on every generator_worker
await self.generator_worker.update_weights.call(version=version)
# Call update_weights on every generator worker
await self.worker.update_weights.call(version=version)
self.generator_version = version

# After updating the weights, we need to reset the KV cache
@@ -511,13 +486,13 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
async def _test_save_model_params(self):
"""Save model parameters before weight update, used for tesing purposes only."""
logger.info("[Generator] save model parameters for testing.")
await self.generator_worker._test_save_model_params.call()
await self.worker._test_save_model_params.call()

@endpoint
async def _test_validate_model_params(self, validate_fn):
"""Validate updated model params using validate_fn."""
logger.info("[Generator] start validating model parameters.")
return await self.generator_worker._test_validate_model_params.call(validate_fn)
return await self.worker._test_validate_model_params.call(validate_fn)


@dataclass
@@ -530,17 +505,9 @@ class GeneratorWorker(ForgeActor):
"""

vllm_config: VllmConfig
state_dict_key: str = "model_state_dict"
# TODO: remove this later since no plumbing exists to change this value.
# Also, whether to use dcp or not can be inferred from torchstore get() call.
use_dcp: bool = True

# used for testing purposes only
# TODO: Remove below param
_test_prev_params = {}

def __post_init__(self):
super().__init__()

@endpoint
async def setup(self):
self.rank = current_rank().rank
@@ -602,19 +569,20 @@ async def update_weights(self, version: int) -> None:
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
loaded_weights = set()
t = Tracer("worker_perf/update_weights", timer="gpu")
t.start()
# Entire state dict is stored in a single DCP handle
if dcp_whole_state_dict_key in matching_keys:

if use_dcp_for_weight_sync:
dcp_handle = await ts.get(dcp_whole_state_dict_key)
hf_param_names = dcp_handle.param_names
for name in hf_param_names:
param = load_tensor_from_dcp(dcp_handle, name)
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
else: # Load each parameter from torchstore directly without DCP
else:
hf_param_names = [extract_param_name(key) for key in matching_keys]
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
@@ -624,6 +592,7 @@ async def update_weights(self, version: int) -> None:
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)

t.stop()
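For orientation, a hypothetical publisher-side sketch of the non-DCP path follows: it shows how a trainer process might stage per-parameter tensors in torchstore so that update_weights can find them under the versioned prefix. The import alias matches the ts usage above, but ts.put and make_param_key are assumptions introduced for illustration; the real key helpers live alongside get_param_prefix and extract_param_name and may differ.

# Hypothetical sketch (not part of this diff): trainer-side publishing of per-parameter weights.
import torchstore as ts  # assumed import alias, matching `ts` above


def make_param_key(version: int, name: str) -> str:
    # Hypothetical key layout; it must agree with whatever get_param_prefix(version)
    # and extract_param_name(key) expect in update_weights above.
    return f"weights/v{version}/{name}"


async def publish_weights(model, version: int) -> None:
    for name, param in model.state_dict().items():
        await ts.put(make_param_key(version, name), param.detach().cpu())  # assumed torchstore API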

@endpoint