diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index fd9f11482..b9e850bc7 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -16,6 +16,7 @@ import torchstore as ts
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore.state_dict_utils import DELIM

+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.utils import _validate_truncation_size

@@ -62,6 +63,8 @@ class SamplingConfig:
     n: int = 1
     guided_decoding: bool = False
     max_tokens: int = 512
+    temperature: float = 1.0
+    top_p: float = 1.0

     def __post_init__(self):
         gd_params = None
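The two new `SamplingConfig` knobs are plain dataclass fields, so they ride along with the existing `asdict(...)` hand-off to vLLM's sampling parameters. A minimal sketch of that flow, assuming only the names visible in this patch (the call site below is illustrative, not code from the repo):

```python
# Hypothetical usage sketch; SamplingConfig comes from the patch above.
from dataclasses import asdict

from forge.actors.policy import SamplingConfig

sampling_config = SamplingConfig(n=4, temperature=0.7, top_p=0.95, max_tokens=256)

# Policy.setup() forwards these values as overrides to
# get_default_sampling_params(self.vllm_config, overrides=asdict(...)).
overrides = asdict(sampling_config)
assert overrides["temperature"] == 0.7 and overrides["top_p"] == 0.95
```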
@@ -89,6 +92,9 @@ class EngineConfig(EngineArgs):
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False

+    # Original method returns False when not run in the main thread
+    _is_v1_supported_oracle = lambda *_: True
+
     @classmethod
     def from_dict(cls, d: Mapping):
         d = dict(d)
@@ -96,6 +102,11 @@ def from_dict(cls, d: Mapping):
         valid_args = {k: v for k, v in d.items() if k in all_fields}
         return cls(**valid_args)

+    def create_vllm_config(self) -> VllmConfig:
+        # This is not a typo: EngineArgs.create_engine_config
+        # creates a VllmConfig
+        return self.create_engine_config(UsageContext.LLM_CLASS)
+

 @dataclass
 class Policy(PolicyInterface):
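A short sketch of the intended call pattern, assuming vLLM's stock `EngineArgs.create_engine_config` behavior; the dict keys are standard `EngineArgs` fields and the model name is a placeholder:

```python
# Hypothetical usage sketch; EngineConfig is from the patch above.
from forge.actors.policy import EngineConfig

engine_config = EngineConfig.from_dict(
    {"model": "facebook/opt-125m", "tensor_parallel_size": 2}
)

# Despite the upstream method name, create_engine_config returns a
# VllmConfig; the new alias makes that explicit at call sites.
vllm_config = engine_config.create_vllm_config()
assert type(vllm_config).__name__ == "VllmConfig"
```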
""" sharding = VLLMSharding( - self.vllm_args.parallel_config.tensor_parallel_size, self.rank + self.vllm_config.parallel_config.tensor_parallel_size, self.rank ) for param_name in current_state_dict.keys(): @@ -455,16 +446,16 @@ async def setup_kv_cache(self): # Get the kv cache tensor size kv_cache_config = get_kv_cache_config( - self.vllm_args, kv_cache_spec, available_gpu_memory + self.vllm_config, kv_cache_spec, available_gpu_memory ) # TODO: unify configs across TorchStore # unify_kv_cache_configs(kv_cache_configs) - self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks - self.vllm_args.cache_config.num_cpu_blocks = 0 + self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks + self.vllm_config.cache_config.num_cpu_blocks = 0 # Initialize kv cache and warmup the execution: # from multiproc_executor.py:MultiprocExecutor.initialize_from_config - kv_cache_configs = [None] * self.vllm_args.parallel_config.world_size + kv_cache_configs = [None] * self.vllm_config.parallel_config.world_size kv_cache_configs[self.rank] = kv_cache_config self.worker.initialize_from_config(kv_cache_configs) self.worker.compile_or_warm_up_model() @@ -472,8 +463,8 @@ async def setup_kv_cache(self): return kv_cache_config @endpoint - async def get_vllm_args(self): - return self.vllm_args + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config @endpoint async def _get_model_params(self) -> dict[str, torch.Tensor]: @@ -488,7 +479,7 @@ async def _get_model_params(self) -> dict[str, torch.Tensor]: def setup_worker(self): """Build and Instantiate vLLM worker""" - parallel_config = self.vllm_args.parallel_config + parallel_config = self.vllm_config.parallel_config set_multiprocessing_worker_envs(parallel_config) ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT") distributed_init_method = get_distributed_init_method(ip, port) @@ -496,13 +487,13 @@ def setup_worker(self): local_rank = self.rank % torch.accelerator.device_count() is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0 all_kwargs[self.rank] = { - "vllm_config": self.vllm_args, + "vllm_config": self.vllm_config, "local_rank": local_rank, "rank": self.rank, "distributed_init_method": distributed_init_method, "is_driver_worker": is_driver_worker, } - worker = WorkerWrapperBase(self.vllm_args, self.rank) + worker = WorkerWrapperBase(self.vllm_config, self.rank) worker.init_worker(all_kwargs) worker.init_device() worker.load_model() diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py index 2ec787b03..08de4f907 100644 --- a/tests/unit_tests/test_policy_config.py +++ b/tests/unit_tests/test_policy_config.py @@ -44,6 +44,7 @@ def test_policy_default_initialization(self): self.assertEqual(policy.engine_config.tensor_parallel_size, 1) self.assertEqual(policy.engine_config.pipeline_parallel_size, 1) self.assertFalse(policy.engine_config.enforce_eager) + self.assertTrue(policy.engine_config._is_v1_supported_oracle()) # Sampling defaults self.assertEqual(policy.sampling_config.n, 1) @@ -91,6 +92,7 @@ def test_policy_with_dict_configs(self): self.assertEqual(policy.engine_config.tensor_parallel_size, 7777) self.assertEqual(policy.engine_config.pipeline_parallel_size, 8888) self.assertTrue(policy.engine_config.enforce_eager) + self.assertTrue(policy.engine_config._is_v1_supported_oracle()) self.assertEqual(policy.sampling_config.n, 1357) # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True @@ -143,6 +145,7 @@ def 
diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py
index 2ec787b03..08de4f907 100644
--- a/tests/unit_tests/test_policy_config.py
+++ b/tests/unit_tests/test_policy_config.py
@@ -44,6 +44,7 @@ def test_policy_default_initialization(self):
         self.assertEqual(policy.engine_config.tensor_parallel_size, 1)
         self.assertEqual(policy.engine_config.pipeline_parallel_size, 1)
         self.assertFalse(policy.engine_config.enforce_eager)
+        self.assertTrue(policy.engine_config._is_v1_supported_oracle())

         # Sampling defaults
         self.assertEqual(policy.sampling_config.n, 1)
@@ -91,6 +92,7 @@ def test_policy_with_dict_configs(self):
         self.assertEqual(policy.engine_config.tensor_parallel_size, 7777)
         self.assertEqual(policy.engine_config.pipeline_parallel_size, 8888)
         self.assertTrue(policy.engine_config.enforce_eager)
+        self.assertTrue(policy.engine_config._is_v1_supported_oracle())

         self.assertEqual(policy.sampling_config.n, 1357)
         # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
@@ -143,6 +145,7 @@ def test_policy_yaml_config_loading(self):
         self.assertEqual(policy.engine_config.tensor_parallel_size, 1234)
         self.assertEqual(policy.engine_config.pipeline_parallel_size, 5678)
         self.assertTrue(policy.engine_config.enforce_eager)
+        self.assertTrue(policy.engine_config._is_v1_supported_oracle())

         self.assertEqual(policy.sampling_config.n, 2468)
         # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
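A final hedged sketch of what these assertions guard: the patch comment says the stock oracle returns `False` when not run in the main thread, which is where the actors here end up. A standalone probe might look like this; the threading harness is illustrative, not part of the test suite:

```python
# Hypothetical check, not part of test_policy_config.py: the class-level
# override should return True regardless of the calling thread.
import threading

from forge.actors.policy import EngineConfig

results = []

def probe():
    # The lambda absorbs `self` via *_ and always answers True.
    results.append(EngineConfig()._is_v1_supported_oracle())

t = threading.Thread(target=probe)
t.start()
t.join()
assert results == [True]
```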