89 changes: 40 additions & 49 deletions src/forge/actors/policy.py
@@ -16,6 +16,7 @@
import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
@@ -62,6 +63,8 @@ class SamplingConfig:
n: int = 1
guided_decoding: bool = False
max_tokens: int = 512
temperature: float = 1.0
top_p: float = 1.0

def __post_init__(self):
gd_params = None
@@ -89,13 +92,21 @@ class EngineConfig(EngineArgs):
pipeline_parallel_size: int = 1
enforce_eager: bool = False

# Original method returns False when not run in the main thread
_is_v1_supported_oracle = lambda *_: True

@classmethod
def from_dict(cls, d: Mapping):
d = dict(d)
all_fields = [f.name for f in fields(cls)]
valid_args = {k: v for k, v in d.items() if k in all_fields}
return cls(**valid_args)

def create_vllm_config(self) -> VllmConfig:
# This is not a typo: EngineArgs.create_engine_config
# creates a VllmConfig
return self.create_engine_config(UsageContext.LLM_CLASS)


@dataclass
class Policy(PolicyInterface):
@@ -143,13 +154,14 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
if isinstance(engine_config, Mapping):
engine_config = EngineConfig.from_dict(engine_config)

if isinstance(engine_config, Mapping):
sampling_config = SamplingConfig(**sampling_config)

vllm_config = engine_config.create_vllm_config()
workers = await worker_procs.spawn(
"vllm_worker", PolicyWorker, vllm_args=engine_config
"vllm_worker", PolicyWorker, vllm_config=vllm_config
)

if isinstance(sampling_config, Mapping):
sampling_config = SamplingConfig(**sampling_config)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
policy = await policy_proc.spawn(
@@ -189,37 +201,37 @@ async def setup(self):
await self.policy_worker.setup.call()

self.request_id = 0
self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_args = await self.policy_worker.get_vllm_args.choose()
self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()

# Setup sampling params
self.sampling_params = get_default_sampling_params(
self.vllm_args, overrides=asdict(self.sampling_config)
self.vllm_config, overrides=asdict(self.sampling_config)
)

# Setup processors
# TODO: move all processing to the Environment
# TODO: add support for `log_stats` and `mm_registry`
tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_args.model_config,
scheduler_config=self.vllm_args.scheduler_config,
lora_config=self.vllm_args.lora_config,
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
)
self.processor = Processor(
vllm_config=self.vllm_args, tokenizer=tokenizer, mm_registry=None
vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
)
self.output_processor = OutputProcessor(tokenizer, log_stats=None)

# Setup scheduler
# TODO: Add support for `log_stats`
kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
kv_cache_config = kv_cache_configs._values[0]
self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_args.cache_config.num_cpu_blocks = 0
self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_config.cache_config.num_cpu_blocks = 0

structured_output_manager = StructuredOutputManager(self.vllm_args)
structured_output_manager = StructuredOutputManager(self.vllm_config)
self.scheduler = Scheduler(
vllm_config=self.vllm_args,
vllm_config=self.vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=structured_output_manager,
include_finished_set=False,
@@ -254,7 +266,7 @@ async def generate(self, prompt: str, priority: int = 0) -> RequestOutput:
# TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
_validate_truncation_size(
self.vllm_args.model_config.max_model_len,
self.vllm_config.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
@@ -316,7 +328,7 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
async def run(self):
# TODO: add support for `iteration_stats`
# TODO: move postprocessing out of loop to not block
parallel_config = self.vllm_args.parallel_config
parallel_config = self.vllm_config.parallel_config
output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
self.running = True
while self.running:
@@ -371,30 +383,9 @@ async def stop(self):

@dataclass
class PolicyWorker(ForgeActor):
vllm_args: EngineConfig | Mapping = EngineConfig()
vllm_config: VllmConfig
state_dict_key: str = "model_state_dict"

def __post_init__(self):
"""Build vLLM Arguments

vLLM specific TODOS
- output format
- check_health
- _aggregate workers output
- register_failure_callback

Testing
- all LLM generate methods, verify against LLM inputs
- all executor methods verify no changes
"""
if isinstance(self.vllm_args, Mapping):
self.vllm_args = EngineConfig.from_dict(self.vllm_args)

# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
# Build Config
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)

@endpoint
async def setup(self):
# TODO: remove ["gpus"] when monarch implements a flat rank
@@ -412,7 +403,7 @@ async def _load_tensor_parallel_state_dict(
Load full state dict from torchstore into tensor parallel model with deterministic sharding.
"""
sharding = VLLMSharding(
self.vllm_args.parallel_config.tensor_parallel_size, self.rank
self.vllm_config.parallel_config.tensor_parallel_size, self.rank
)

for param_name in current_state_dict.keys():
@@ -455,25 +446,25 @@ async def setup_kv_cache(self):

# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(
self.vllm_args, kv_cache_spec, available_gpu_memory
self.vllm_config, kv_cache_spec, available_gpu_memory
)
# TODO: unify configs across TorchStore
# unify_kv_cache_configs(kv_cache_configs)
self.vllm_args.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_args.cache_config.num_cpu_blocks = 0
self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_config.cache_config.num_cpu_blocks = 0

# Initialize kv cache and warmup the execution:
# from multiproc_executor.py:MultiprocExecutor.initialize_from_config
kv_cache_configs = [None] * self.vllm_args.parallel_config.world_size
kv_cache_configs = [None] * self.vllm_config.parallel_config.world_size
kv_cache_configs[self.rank] = kv_cache_config
self.worker.initialize_from_config(kv_cache_configs)
self.worker.compile_or_warm_up_model()
self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
return kv_cache_config

@endpoint
async def get_vllm_args(self):
return self.vllm_args
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
@@ -488,21 +479,21 @@ async def _get_model_params(self) -> dict[str, torch.Tensor]:

def setup_worker(self):
"""Build and Instantiate vLLM worker"""
parallel_config = self.vllm_args.parallel_config
parallel_config = self.vllm_config.parallel_config
set_multiprocessing_worker_envs(parallel_config)
ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
distributed_init_method = get_distributed_init_method(ip, port)
all_kwargs = [{}] * parallel_config.world_size
local_rank = self.rank % torch.accelerator.device_count()
is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0
all_kwargs[self.rank] = {
"vllm_config": self.vllm_args,
"vllm_config": self.vllm_config,
"local_rank": local_rank,
"rank": self.rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": is_driver_worker,
}
worker = WorkerWrapperBase(self.vllm_args, self.rank)
worker = WorkerWrapperBase(self.vllm_config, self.rank)
worker.init_worker(all_kwargs)
worker.init_device()
worker.load_model()
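For reviewers who want the shape of the refactor at a glance, here is a minimal sketch of the new config hand-off, assuming the module is importable as forge.actors.policy and using a placeholder model name; the actor spawning, torchstore wiring, and KV-cache setup from the real launch path are elided, and PolicyWorker is constructed directly only for illustration.

from forge.actors.policy import EngineConfig, PolicyWorker  # assumed import path

# Unknown keys are dropped by EngineConfig.from_dict, so a loose dict or YAML
# mapping can be passed straight through (values here are placeholders).
engine_config = EngineConfig.from_dict(
    {"model": "facebook/opt-125m", "tensor_parallel_size": 1, "not_a_field": "ignored"}
)

# Despite the name, EngineArgs.create_engine_config returns a VllmConfig;
# create_vllm_config is just a readability wrapper around that call.
vllm_config = engine_config.create_vllm_config()

# Workers now receive the fully built VllmConfig instead of rebuilding it
# from EngineConfig in __post_init__ (direct construction shown for brevity).
worker = PolicyWorker(vllm_config=vllm_config)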
3 changes: 3 additions & 0 deletions tests/unit_tests/test_policy_config.py
@@ -44,6 +44,7 @@ def test_policy_default_initialization(self):
self.assertEqual(policy.engine_config.tensor_parallel_size, 1)
self.assertEqual(policy.engine_config.pipeline_parallel_size, 1)
self.assertFalse(policy.engine_config.enforce_eager)
self.assertTrue(policy.engine_config._is_v1_supported_oracle())

# Sampling defaults
self.assertEqual(policy.sampling_config.n, 1)
@@ -91,6 +92,7 @@ def test_policy_with_dict_configs(self):
self.assertEqual(policy.engine_config.tensor_parallel_size, 7777)
self.assertEqual(policy.engine_config.pipeline_parallel_size, 8888)
self.assertTrue(policy.engine_config.enforce_eager)
self.assertTrue(policy.engine_config._is_v1_supported_oracle())

self.assertEqual(policy.sampling_config.n, 1357)
# After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
@@ -143,6 +145,7 @@ def test_policy_yaml_config_loading(self):
self.assertEqual(policy.engine_config.tensor_parallel_size, 1234)
self.assertEqual(policy.engine_config.pipeline_parallel_size, 5678)
self.assertTrue(policy.engine_config.enforce_eager)
self.assertTrue(policy.engine_config._is_v1_supported_oracle())

self.assertEqual(policy.sampling_config.n, 2468)
# After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
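The three new assertTrue checks lean on the fact that the oracle override is a class-level lambda that swallows whatever it is passed. A tiny standalone illustration, using a hypothetical FakeEngineConfig class that is not part of the test suite:

# Hypothetical stand-in for EngineConfig, showing why the assertion always holds:
# accessed through an instance, the lambda becomes a bound method, so the instance
# (and any further arguments vLLM might pass) are absorbed by *_ and True is returned.
class FakeEngineConfig:
    _is_v1_supported_oracle = lambda *_: True

cfg = FakeEngineConfig()
assert cfg._is_v1_supported_oracle() is True
assert cfg._is_v1_supported_oracle("extra", "args") is True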