diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 8f8cf8fc7..ca934127e 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -22,7 +22,6 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
-from vllm.lora.request import LoRARequest
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -53,7 +52,6 @@
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
-from forge.interfaces import Policy as GeneratorInterface
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ -63,8 +61,8 @@
 
 
 @dataclass
-class Generator(GeneratorInterface):
-    """Instance of a vLLM-based Generator.
+class Generator(ForgeActor):
+    """Instance of a vLLM-based generator.
 
     This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1.
     The main difference is that all communications are controlled here via Monarch's proc meshes.
@@ -72,11 +70,10 @@ class Generator(GeneratorInterface):
     Args:
         engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
         sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
-        available_devices (str): The available devices to use for the vLLM engine.
-        use_dcp (bool): Whether to use DCP for NFS-based weight sync.
+        use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on
+            whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled.
 
     Example:
-
         >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
         ...     engine_args=EngineArgs(...),
         ...     sampling_params=SamplingParams(...),
@@ -89,50 +86,50 @@ class Generator(GeneratorInterface):
 
     engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
     sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
-    available_devices: str | None = None
-    use_dcp: bool = (
-        TORCHSTORE_USE_RDMA.get_value() == 0
-    )  # torchstore currently only accepts 0 or 1
-    # Remaining variables are initialized in self.setup()
-    lora_request: LoRARequest | None = None
-    tokenization_kwargs: dict = field(default_factory=dict)
-    generator_worker: GeneratorWorker | None = None
+    use_dcp_for_weight_sync: bool | None = None
 
     def __post_init__(self):
         super().__init__()
         self._run_task: asyncio.Task | None = None
         self._generator_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
+        self.worker: GeneratorWorker | None = None
         self.running = False
        self.generator_version: int = 0
 
         if isinstance(self.engine_args, Mapping):
             self.engine_args = EngineArgs(**self.engine_args)
         self.engine_args._is_v1_supported_oracle = lambda *_: True
+        self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
 
         if isinstance(self.sampling_params, Mapping):
             self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
         self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
 
+        if self.use_dcp_for_weight_sync is None:
+            self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value()
+        logger.debug(f"{self.use_dcp_for_weight_sync=}")
+
+    @endpoint
+    async def get_vllm_config(self) -> VllmConfig:
+        return self.vllm_config
+
+    @endpoint
+    async def register_worker(self, worker: GeneratorWorker) -> None:
+        self.worker = worker
+        logger.debug("Registered GeneratorWorker on Generator.")
+
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Generator"],
-        *,
-        engine_args: EngineArgs | Mapping = EngineArgs(),
-        sampling_params: SamplingParams | Mapping = SamplingParams(),
-        available_devices: str | None = None,
-        use_dcp: bool = (
-            TORCHSTORE_USE_RDMA.get_value() == 0
-        ),  # torchstore currently only accepts 0 or 1
+        *args,
         **kwargs,
     ) -> "Generator":
-        """Launch the Generator with its workers.
+        """Custom launch for the Generator service with its workers.
 
         We overwrite the default Service launch method in order to setup Actors (GeneratorWorker)
        within this "coordinating" Actor. We first create a proc_mesh for the workers, then a
         proc_mesh for the generator, and then we spawn the workers and the generator in setup.
-
-        The args here generally should match those in the `__init__` method of the Generator class.
         """
         # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
         process_config: ProcessConfig = ProcessConfig(
@@ -141,60 +138,46 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             with_gpus=cls.with_gpus,
             mesh_name=cls.mesh_name,
         )
-        worker_procs = await get_proc_mesh(process_config=process_config)
 
         # TODO - issues/144 we will want to ensure colocation with workers
         # We're currently locating the Generator on the local host proc mesh
         # vLLM initialization without setting env variables at proc_mesh creation
-        # level leads to issues.
-        # Once we can create multiple proc meshes on a host mesh, we can ensure
-        # host colocation
+        # level leads to issues. Once we can create multiple proc meshes on a host mesh,
+        # we can ensure host colocation
         generator_proc_config = copy(process_config)
         generator_proc_config.procs = 1
         generator_proc_config.hosts = None
         generator_proc_config.with_gpus = False
         generator_proc = await get_proc_mesh(process_config=generator_proc_config)
 
-        if isinstance(engine_args, Mapping):
-            engine_args = EngineArgs(**engine_args)
-        engine_args._is_v1_supported_oracle = lambda *_: True  # Always default on
-        logger.debug(f"Resolved engine args: {engine_args}")
-
-        vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
-        workers = worker_procs.spawn(
-            "vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp
-        )
-
-        if isinstance(sampling_params, Mapping):
-            sampling_params = SamplingParams.from_optional(**sampling_params)
-        sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
-        logger.debug(f"Resolved sampling params: {sampling_params}")
-
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
         generator = generator_proc.spawn(
             actor_name,
             cls,
-            engine_args=engine_args,
-            sampling_params=sampling_params,
-            available_devices=available_devices,
-            generator_worker=workers,
+            *args,
             **kwargs,
         )
+
+        worker_procs = await get_proc_mesh(process_config=process_config)
+        vllm_config = (
+            await generator.get_vllm_config.call_one()
+        )  # Config should be the same across all actors
+        worker = worker_procs.spawn(
+            "vllm_worker", GeneratorWorker, vllm_config=vllm_config
+        )
+        await worker.setup.call()
+        await generator.register_worker.call(worker)
+
         generator._generator_proc = generator_proc
         generator._worker_procs = worker_procs
         await generator.setup.call()
+
         return generator
 
     @endpoint
     async def setup(self):
         """Mirrors the __init__ of vLLM's LLMEngine."""
-        if self.generator_worker is None:
-            raise RuntimeError(
-                "Geneator worker should not be None. Usually it would be attached to Generator in the ``launch`` method."
-            )
-        await self.generator_worker.setup.call()
-
         self.request_id = 0
         self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
 
@@ -204,35 +187,30 @@ async def setup(self):
         self.request_lock = asyncio.Condition()  # Guard for accepting_requests
         self.update_lock = asyncio.Condition()  # Guard for updating requests
 
-        vllm_config: VllmConfig = self.engine_args.create_engine_config(
-            UsageContext.LLM_CLASS
-        )
-        self.max_model_len = vllm_config.model_config.max_model_len
-
         # Setup processors
         # TODO: move all processing to the Environment
         # TODO: add support for `log_stats` and `mm_registry`
         tokenizer = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config,
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            lora_config=self.vllm_config.lora_config,
         )
         self.processor = Processor(
-            vllm_config=vllm_config, tokenizer=tokenizer, mm_registry=None
+            vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
         )
         self.output_processor = OutputProcessor(tokenizer, log_stats=None)
 
         # Configure KV caches
-        kv_cache_configs = await self.generator_worker.setup_kv_cache.call()
+        kv_cache_configs = await self.worker.setup_kv_cache.call()
         _, kv_cache_config = next(kv_cache_configs.items())
-        vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
-        vllm_config.cache_config.num_cpu_blocks = 0
+        self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        self.vllm_config.cache_config.num_cpu_blocks = 0
 
         # Setup scheduler
         # TODO: Add support for `log_stats`
-        structured_output_manager = StructuredOutputManager(vllm_config)
+        structured_output_manager = StructuredOutputManager(self.vllm_config)
         self.scheduler = Scheduler(
-            vllm_config=vllm_config,
+            vllm_config=self.vllm_config,
             kv_cache_config=kv_cache_config,
             structured_output_manager=structured_output_manager,
             include_finished_set=False,
@@ -262,11 +240,11 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         self.request_id += 1 % sys.maxsize
         request_id = str(self.request_id)
 
-        tokenization_kwargs = self.tokenization_kwargs or {}
+        tokenization_kwargs = {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
         truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
         _validate_truncation_size(
-            self.max_model_len,
+            self.vllm_config.model_config.max_model_len,
             truncate_prompt_tokens,
             tokenization_kwargs,
         )
@@ -275,7 +253,6 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
             prompt={"prompt": prompt},
             params=self.sampling_params,
             arrival_time=None,
-            lora_request=self.lora_request,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
             priority=priority,
@@ -360,9 +337,7 @@ async def run(self) -> None:
         self.running = True
         while self.running:
             scheduler_output = self.scheduler.schedule()
-            worker_outputs = await self.generator_worker.execute_model.call(
-                scheduler_output
-            )
+            worker_outputs = await self.worker.execute_model.call(scheduler_output)
 
             # The results of `execute_model` are gathered on the driver rank (rank 0)
             _, worker_output = next(worker_outputs.items())
@@ -431,8 +406,8 @@ async def update_weights(self, version: int) -> None:
         )
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
 
-        # Call update_weights on every generator_worker
-        await self.generator_worker.update_weights.call(version=version)
+        # Call update_weights on every generator worker
+        await self.worker.update_weights.call(version=version)
         self.generator_version = version
 
         # After updating the weights, we need to reset the KV cache
@@ -511,13 +486,13 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
     async def _test_save_model_params(self):
         """Save model parameters before weight update, used for tesing purposes only."""
         logger.info("[Generator] save model parameters for testing.")
-        await self.generator_worker._test_save_model_params.call()
+        await self.worker._test_save_model_params.call()
 
     @endpoint
     async def _test_validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
         logger.info("[Generator] start validating model parameters.")
-        return await self.generator_worker._test_validate_model_params.call(validate_fn)
+        return await self.worker._test_validate_model_params.call(validate_fn)
 
 
 @dataclass
@@ -530,17 +505,9 @@ class GeneratorWorker(ForgeActor):
     """
 
     vllm_config: VllmConfig
-    state_dict_key: str = "model_state_dict"
-    # TODO: remove this later since no plumbing exists to change this value.
-    # Also, whether to use dcp or not can be inferred from torchstore get() call.
-    use_dcp: bool = True
-
-    # used for tesing purposes only
+    # TODO: Remove below param
     _test_prev_params = {}
 
-    def __post_init__(self):
-        super().__init__()
-
     @endpoint
     async def setup(self):
         self.rank = current_rank().rank
@@ -602,11 +569,12 @@ async def update_weights(self, version: int) -> None:
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
+        use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
         loaded_weights = set()
         t = Tracer("worker_perf/update_weights", timer="gpu")
         t.start()
-        # Entire state dict is stored in a single DCP handle
-        if dcp_whole_state_dict_key in matching_keys:
+
+        if use_dcp_for_weight_sync:
             dcp_handle = await ts.get(dcp_whole_state_dict_key)
             hf_param_names = dcp_handle.param_names
             for name in hf_param_names:
@@ -614,7 +582,7 @@ async def update_weights(self, version: int) -> None:
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
-        else:  # Load each parameter from torchstore directly without DCP
+        else:
             hf_param_names = [extract_param_name(key) for key in matching_keys]
             # We can't pass a generator since vllm load_weights is not async.
             # Instead, we just call load_weights with one parameter at a time.
@@ -624,6 +592,7 @@ async def update_weights(self, version: int) -> None:
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
+
         t.stop()
 
     @endpoint
diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py
index 1c64e42e2..94cb58859 100644
--- a/tests/unit_tests/test_generator_config.py
+++ b/tests/unit_tests/test_generator_config.py
@@ -39,7 +39,6 @@ def test_generator_default_initialization(self):
         # Default factories
         self.assertIsInstance(generator.engine_args, EngineArgs)
         self.assertIsInstance(generator.sampling_params, SamplingParams)
-        self.assertIsNone(generator.available_devices)
 
         # Worker defaults
         self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
@@ -58,46 +57,43 @@ def test_generator_default_initialization(self):
         reason="Import error, likely due to missing dependencies on CI.",
     )
     def test_generator_with_dict_configs(self):
-        """Generator accepts dicts for engine_config and sampling_config, including nested dicts."""
         from forge.actors.generator import Generator
         from vllm.engine.arg_utils import EngineArgs
         from vllm.sampling_params import SamplingParams
 
-        # Test with nested dict structure
         engine_dict = {
-            "model": "test-model-6789",
-            "tensor_parallel_size": 7777,
-            "pipeline_parallel_size": 8888,
+            "model": "Qwen/Qwen3-0.6B",
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
             "enforce_eager": True,
-            "gpu_memory_utilization": 0.9,
-            "max_model_len": 4096,
+            "gpu_memory_utilization": 0.1,
+            "max_model_len": 1024,
         }
 
         sampling_dict = {
-            "n": 1357,
-            "max_tokens": 2468,
+            "n": 2,
+            "max_tokens": 32,
         }
 
         generator = Generator(
             engine_args=engine_dict,
             sampling_params=sampling_dict,
-            available_devices="test-gpu-device-abcd",
         )
 
         self.assertIsInstance(generator.engine_args, EngineArgs)
         self.assertIsInstance(generator.sampling_params, SamplingParams)
 
         # Test basic fields
-        self.assertEqual(generator.engine_args.model, "test-model-6789")
-        self.assertEqual(generator.engine_args.tensor_parallel_size, 7777)
-        self.assertEqual(generator.engine_args.pipeline_parallel_size, 8888)
-        self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.9)
-        self.assertEqual(generator.engine_args.max_model_len, 4096)
+        self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
+        self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
+        self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
+        self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.1)
+        self.assertEqual(generator.engine_args.max_model_len, 1024)
         self.assertTrue(generator.engine_args.enforce_eager)
         self.assertTrue(generator.engine_args._is_v1_supported_oracle())
 
-        self.assertEqual(generator.sampling_params.n, 1357)
-        self.assertEqual(generator.sampling_params.max_tokens, 2468)
+        self.assertEqual(generator.sampling_params.n, 2)
+        self.assertEqual(generator.sampling_params.max_tokens, 32)
 
     @pytest.mark.skipif(
         _import_error(),
@@ -109,16 +105,14 @@ def test_generator_yaml_config_loading(self):
 
         yaml_content = """
 engine_args:
-  model: "yaml-test-model-9876"
-  tensor_parallel_size: 1234
-  pipeline_parallel_size: 5678
+  model: "Qwen/Qwen3-0.6B"
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
   enforce_eager: true
 
 sampling_params:
-  n: 2468
-  max_tokens: 1357
-
-available_devices: "yaml-test-device-xyz"
+  n: 2
+  max_tokens: 32
 """
 
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
@@ -129,16 +123,14 @@ def test_generator_yaml_config_loading(self):
             config = yaml.safe_load(yaml_file)
             generator = Generator(**config)
 
-            self.assertEqual(generator.engine_args.model, "yaml-test-model-9876")
-            self.assertEqual(generator.engine_args.tensor_parallel_size, 1234)
-            self.assertEqual(generator.engine_args.pipeline_parallel_size, 5678)
+            self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
+            self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
+            self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
             self.assertTrue(generator.engine_args.enforce_eager)
             self.assertTrue(generator.engine_args._is_v1_supported_oracle())
 
-            self.assertEqual(generator.sampling_params.n, 2468)
-            self.assertEqual(generator.sampling_params.max_tokens, 1357)
-
-            self.assertEqual(generator.available_devices, "yaml-test-device-xyz")
+            self.assertEqual(generator.sampling_params.n, 2)
+            self.assertEqual(generator.sampling_params.max_tokens, 32)
 
 
 if __name__ == "__main__":