diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index ed33e7e2f..3c1b9e28f 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -20,7 +20,7 @@
     get_dcp_whole_state_dict_key,
     get_param_prefix,
 )
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import RLTrainer
@@ -79,6 +79,9 @@ def response_tensor(self) -> torch.Tensor:
 # Represents the group (G) of episodes in GRPO
 Group = list[Episode]
 
+# Represents the Policy Model to collect data from
+Policy = Generator
+
 
 def collate(
     batches: list[Group],
diff --git a/docs/source/api_generator.md b/docs/source/api_generator.md
index c5aee3eec..31b67c03c 100644
--- a/docs/source/api_generator.md
+++ b/docs/source/api_generator.md
@@ -1,26 +1,26 @@
 # Generator
 
 ```{eval-rst}
-.. currentmodule:: forge.actors.policy
+.. currentmodule:: forge.actors.generator
 ```
 
 The Generator (Policy) is the core inference engine in TorchForge, built on top of [vLLM](https://docs.vllm.ai/en/latest/). It manages model serving, text generation, and weight updates for reinforcement learning workflows.
 
-## Policy
+## Generator
 
 ```{eval-rst}
-.. autoclass:: Policy
+.. autoclass:: Generator
    :members: generate, update_weights, get_version, stop
    :exclude-members: __init__, launch
    :no-inherited-members:
 ```
 
-## PolicyWorker
+## GeneratorWorker
 
 ```{eval-rst}
-.. autoclass:: PolicyWorker
+.. autoclass:: GeneratorWorker
    :members: execute_model, update, setup_kv_cache
    :show-inheritance:
    :exclude-members: __init__
diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py
index 1bffade94..cb0fb9300 100644
--- a/src/forge/actors/__init__.py
+++ b/src/forge/actors/__init__.py
@@ -5,8 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 __all__ = [
-    "Policy",
-    "PolicyRouter",
+    "Generator",
+    "PolicyRouter",
     "RLTrainer",
     "ReplayBuffer",
     "TitanRefModel",
@@ -15,10 +15,10 @@ def __getattr__(name):
-    if name == "Policy":
-        from .policy import Policy
+    if name == "Generator":
+        from .generator import Generator
 
-        return Policy
+        return Generator
     elif name == "PolicyRouter":
         from .policy import PolicyRouter
diff --git a/src/forge/actors/policy.py b/src/forge/actors/generator.py
similarity index 83%
rename from src/forge/actors/policy.py
rename to src/forge/actors/generator.py
index 7a1cdfd15..8f8cf8fc7 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/generator.py
@@ -53,7 +53,7 @@
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
-from forge.interfaces import Policy as PolicyInterface
+from forge.interfaces import Policy as GeneratorInterface
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ -63,8 +63,8 @@
 
 
 @dataclass
-class Policy(PolicyInterface):
-    """Instance of a vLLM-based Policy.
+class Generator(GeneratorInterface):
+    """Instance of a vLLM-based Generator.
 
     This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1.
     The main difference is that all communications are controlled here via Monarch's proc meshes.
@@ -77,14 +77,14 @@ class Policy(PolicyInterface):
 
     Example:
-        >>> policy = await Policy.options(procs=1, num_replicas=1, with_gpus=True).as_service(
+        >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
         ...     engine_args=EngineArgs(...),
         ...     sampling_params=SamplingParams(...),
         ... )
-        >>> await policy.generate("Tell me a joke")
+        >>> await generator.generate("Tell me a joke")
         Completion(prompt="Tell me a joke", text="A: Why did the chicken cross the road? B: To get to the other side.", token_ids=[...], logprobs=[...])
-        >>> await policy.shutdown()
+        >>> await generator.shutdown()
     """
 
     engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
@@ -96,15 +96,15 @@ class Policy(PolicyInterface):
     # Remaining variables are initialized in self.setup()
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
-    policy_worker: PolicyWorker | None = None
+    generator_worker: GeneratorWorker | None = None
 
     def __post_init__(self):
         super().__init__()
         self._run_task: asyncio.Task | None = None
-        self._policy_proc: ProcMesh | None = None
+        self._generator_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.running = False
-        self.policy_version: int = 0
+        self.generator_version: int = 0
 
         if isinstance(self.engine_args, Mapping):
             self.engine_args = EngineArgs(**self.engine_args)
@@ -116,7 +116,7 @@ def __post_init__(self):
 
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
-        cls: type["Policy"],
+        cls: type["Generator"],
         *,
         engine_args: EngineArgs | Mapping = EngineArgs(),
         sampling_params: SamplingParams | Mapping = SamplingParams(),
@@ -125,14 +125,14 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             TORCHSTORE_USE_RDMA.get_value() == 0
         ),  # torchstore currently only accepts 0 or 1
         **kwargs,
-    ) -> "Policy":
-        """Launch the policy with its workers.
+    ) -> "Generator":
+        """Launch the Generator with its workers.
 
-        We overwrite the default Service launch method in order to setup Actors (PolicyWorker) within this "coordinating" Actor.
-        We first create a proc_mesh for the workers, then a proc_mesh for the policy, and then we spawn the workers
-        and the policy in setup.
+        We overwrite the default Service launch method in order to setup Actors (GeneratorWorker) within this "coordinating" Actor.
+        We first create a proc_mesh for the workers, then a proc_mesh for the generator, and then we spawn the workers
+        and the generator in setup.
 
-        The args here generally should match those in the `__init__` method of the Policy class.
+        The args here generally should match those in the `__init__` method of the Generator class.
         """
         # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
         process_config: ProcessConfig = ProcessConfig(
@@ -144,16 +144,16 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         worker_procs = await get_proc_mesh(process_config=process_config)
 
         # TODO - issues/144 we will want to ensure colocation with workers
-        # We're currently locating the Policy on the local host proc mesh
+        # We're currently locating the Generator on the local host proc mesh
         # vLLM initialization without setting env variables at proc_mesh creation
         # level leads to issues.
         # Once we can create multiple proc meshes on a host mesh, we can ensure
         # host colocation
-        policy_proc_config = copy(process_config)
-        policy_proc_config.procs = 1
-        policy_proc_config.hosts = None
-        policy_proc_config.with_gpus = False
-        policy_proc = await get_proc_mesh(process_config=policy_proc_config)
+        generator_proc_config = copy(process_config)
+        generator_proc_config.procs = 1
+        generator_proc_config.hosts = None
+        generator_proc_config.with_gpus = False
+        generator_proc = await get_proc_mesh(process_config=generator_proc_config)
 
         if isinstance(engine_args, Mapping):
             engine_args = EngineArgs(**engine_args)
@@ -162,7 +162,7 @@
         vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
 
         workers = worker_procs.spawn(
-            "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
+            "vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp
         )
 
         if isinstance(sampling_params, Mapping):
@@ -172,33 +172,33 @@
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
-        policy = policy_proc.spawn(
+        generator = generator_proc.spawn(
             actor_name,
             cls,
             engine_args=engine_args,
             sampling_params=sampling_params,
             available_devices=available_devices,
-            policy_worker=workers,
+            generator_worker=workers,
             **kwargs,
         )
-        policy._policy_proc = policy_proc
-        policy._worker_procs = worker_procs
-        await policy.setup.call()
-        return policy
+        generator._generator_proc = generator_proc
+        generator._worker_procs = worker_procs
+        await generator.setup.call()
+        return generator
 
     @endpoint
     async def setup(self):
         """Mirrors the __init__ of vLLM's LLMEngine."""
-        if self.policy_worker is None:
+        if self.generator_worker is None:
             raise RuntimeError(
-                "Policy worker should not be None. Usually it would be attached to Policy in the ``launch`` method."
+                "Generator worker should not be None. Usually it would be attached to Generator in the ``launch`` method."
             )
-        await self.policy_worker.setup.call()
+        await self.generator_worker.setup.call()
 
         self.request_id = 0
         self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
-        # TODO: Investigate whether this can be combined with `policy.running`
+        # TODO: Investigate whether this can be combined with `generator.running`
         self.accepting_requests = True
         self.request_lock = asyncio.Condition()  # Guard for accepting_requests
@@ -223,7 +223,7 @@ async def setup(self):
         self.output_processor = OutputProcessor(tokenizer, log_stats=None)
 
         # Configure KV caches
-        kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
+        kv_cache_configs = await self.generator_worker.setup_kv_cache.call()
         _, kv_cache_config = next(kv_cache_configs.items())
         vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
         vllm_config.cache_config.num_cpu_blocks = 0
@@ -255,9 +255,9 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         Returns:
             list[Completion]: n completions from vLLM based on your prompt.
""" - t = Tracer("policy_perf/generate", timer="gpu") + t = Tracer("generator_perf/generate", timer="gpu") t.start() - record_metric("policy/generate/count_requests", 1, Reduce.SUM) + record_metric("generator/generate/count_requests", 1, Reduce.SUM) self.request_id += 1 % sys.maxsize request_id = str(self.request_id) @@ -318,7 +318,7 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: # Log some metrics record_metric( - "policy/generate/count_sequences_completed", + "generator/generate/count_sequences_completed", len(completions), Reduce.SUM, ) @@ -326,13 +326,13 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: for completion in completions: num_generated_tokens = len(completion.token_ids) record_metric( - "policy/generate/sum_tokens_generated", + "generator/generate/sum_tokens_generated", num_generated_tokens, Reduce.SUM, ) record_metric( - "policy/generate/avg_tokens_generated", + "generator/generate/avg_tokens_generated", num_generated_tokens, Reduce.MEAN, ) @@ -360,7 +360,7 @@ async def run(self) -> None: self.running = True while self.running: scheduler_output = self.scheduler.schedule() - worker_outputs = await self.policy_worker.execute_model.call( + worker_outputs = await self.generator_worker.execute_model.call( scheduler_output ) @@ -387,18 +387,18 @@ async def run(self) -> None: self.request_lock.notify_all() @endpoint - async def update_weights(self, policy_version: int) -> None: - """Update weights on base model from a policy version to be found in a torchstore volume. + async def update_weights(self, version: int) -> None: + """Update weights on base model from a generator version to be found in a torchstore volume. Args: - policy_version (int): Policy version from which to update. This will correspond to a key in a + generator_version (int): Generator version from which to update. This will correspond to a key in a torchstore volume. Example: >>> trainer.train_step(...) 
             >>> version += 1
             >>> await trainer.push_weights()
-            >>> policy.update_weights(version)
+            >>> generator.update_weights(version)
         """
         # Serialize updates (only one update at a time)
         async with self.update_lock:
@@ -409,12 +409,12 @@ async def update_weights(self, policy_version: int) -> None:
             if curr_requests:
                 # Record pending requests metrics
                 record_metric(
-                    "policy_perf/update_weights/avg_pending_requests",
+                    "generator_perf/update_weights/avg_pending_requests",
                     len(curr_requests),
                     Reduce.MEAN,
                 )
                 record_metric(
-                    "policy_perf/update_weights/max_pending_requests",
+                    "generator_perf/update_weights/max_pending_requests",
                     len(curr_requests),
                     Reduce.MAX,
                 )
@@ -422,16 +422,18 @@ async def update_weights(self, policy_version: int) -> None:
 
             # Wait until all pending requests have been processed
             # TODO: If generating long sequences, this might be long and will block
-            # policy weight updates
+            # generator weight updates
             await self.request_lock.wait_for(lambda: len(self.requests) == 0)
 
             # Record weight update metrics
-            record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
+            record_metric(
+                "generator/update_weights/count_weight_updates", 1, Reduce.SUM
+            )
 
             logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        # Call update_weights on every policy_worker
-        await self.policy_worker.update_weights.call(policy_version)
-        self.policy_version = policy_version
+        # Call update_weights on every generator_worker
+        await self.generator_worker.update_weights.call(version=version)
+        self.generator_version = version
 
         # After updating the weights, we need to reset the KV cache
         self.scheduler.reset_prefix_cache()
@@ -441,12 +443,17 @@ async def update_weights(self, policy_version: int) -> None:
             self.accepting_requests = True
             self.request_lock.notify_all()
 
-        logger.info(f"Weight update completed (now v{self.policy_version})")
+        logger.info(f"Weight update completed (now v{self.generator_version})")
 
     @endpoint
     async def _reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
 
+    @endpoint
+    async def get_version(self) -> int:
+        """Get the current generator version."""
+        return self.generator_version
+
     @endpoint
     async def stop(self):
         self.running = False
@@ -467,7 +474,7 @@ def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
                     prompt_ids=torch.tensor(prompt_token_ids),
                     token_ids=torch.tensor(output.token_ids),
                     logprobs=self._extract_logprobs(output),
-                    generator_version=self.policy_version,
+                    generator_version=self.generator_version,
                     metadata={"num_cached_tokens": request_output.num_cached_tokens},
                 )
             )
@@ -485,41 +492,41 @@ def _extract_logprobs(self, sample: CompletionOutput) -> torch.Tensor | None:
 
     @classmethod
     async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
-        cls: type["Policy"], actor: "Policy"
+        cls: type["Generator"], actor: "Generator"
     ):
         assert (
-            actor._policy_proc is not None
-        ), "Tried to shutdown a policy that was not initialized correctly"
+            actor._generator_proc is not None
+        ), "Tried to shutdown a generator that was not initialized correctly"
         assert (
             actor._worker_procs is not None
-        ), "Tried to shutdown a policy that was not initialized correctly"
+        ), "Tried to shutdown a generator that was not initialized correctly"
 
         # TODO - may want to expand stop to gracefully respond to
         # ongoing requests.
         await actor.stop.call()
         await stop_proc_mesh(actor._worker_procs)
-        await stop_proc_mesh(actor._policy_proc)
+        await stop_proc_mesh(actor._generator_proc)
 
     @endpoint
     async def _test_save_model_params(self):
         """Save model parameters before weight update, used for tesing purposes only."""
-        logger.info("[Policy] save model parameters for testing.")
-        await self.policy_worker._test_save_model_params.call()
+        logger.info("[Generator] save model parameters for testing.")
+        await self.generator_worker._test_save_model_params.call()
 
     @endpoint
     async def _test_validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
-        logger.info("[Policy] start validating model parameters.")
-        return await self.policy_worker._test_validate_model_params.call(validate_fn)
+        logger.info("[Generator] start validating model parameters.")
+        return await self.generator_worker._test_validate_model_params.call(validate_fn)
 
 
 @dataclass
-class PolicyWorker(ForgeActor):
+class GeneratorWorker(ForgeActor):
     """Mirrors a vLLM GPUWorker
     https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/worker/gpu_worker.py
 
-    In general, this class should not be instantiated or called directly. Rather, the Policy controls
-    the creation and invocation of all PolicyWorkers.
+    In general, this class should not be instantiated or called directly. Rather, the Generator controls
+    the creation and invocation of all GeneratorWorkers.
     """
 
     vllm_config: VllmConfig
@@ -590,13 +597,13 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
         return self.worker.execute_model(schedule)
 
     @endpoint
-    async def update_weights(self, policy_version: int) -> None:
+    async def update_weights(self, version: int) -> None:
         model = self.worker.model_runner.model
-        prefix = get_param_prefix(policy_version)
+        prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
-        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(policy_version)
+        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
         loaded_weights = set()
-        t = Tracer("policy_worker_perf/update_weights", timer="gpu")
+        t = Tracer("worker_perf/update_weights", timer="gpu")
         t.start()
         # Entire state dict is stored in a single DCP handle
         if dcp_whole_state_dict_key in matching_keys:
@@ -612,7 +619,7 @@ async def update_weights(self, policy_version: int) -> None:
             # We can't pass a generator since vllm load_weights is not async.
             # Instead, we just call load_weights with one parameter at a time.
             for name in hf_param_names:
-                param_key = get_param_key(policy_version, name)
+                param_key = get_param_key(version, name)
                 param = await ts.get(param_key)
                 loaded = model.load_weights([(name, param)])
                 del param
@@ -622,18 +629,18 @@ async def update_weights(self, policy_version: int) -> None:
     @endpoint
     async def _test_save_model_params(self):
         """Save model parameters before weight update, used for tesing purposes only."""
-        logger.info("[PolicyWorker] save model parameters for testing.")
+        logger.info("[GeneratorWorker] save model parameters for testing.")
         for name, param in self.worker.model_runner.model.named_parameters():
             self._test_prev_params[name] = param.detach().cpu()
         logger.info(
-            "[PolicyWorker] finished saving model parameters, len = %d",
+            "[GeneratorWorker] finished saving model parameters, len = %d",
             len(self._test_prev_params),
         )
 
     @endpoint
     async def _test_validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
-        logger.info("[PolicyWorker] start validating model parameters.")
+        logger.info("[GeneratorWorker] start validating model parameters.")
         return validate_fn(
             self._test_prev_params, self.worker.model_runner.model, logger
         )
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 2f9addfe6..83ddd349e 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -62,7 +62,7 @@ async def get_or_create_metric_logger(
     })
 
     # Initialize services...
-    policy = await Policy.as_service(...)
+    policy = await Generator.as_service(...)
 
     # Training loop
     for step in range(max_steps):
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index bbc41bc5c..7edf0fcf3 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -12,7 +12,7 @@
 
 import torch
 import torchstore as ts
 
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import resolve_hf_hub_paths
@@ -203,7 +203,7 @@ async def test_sanity_check(self, request):
             trainer_cfg["dcp_path"] = tmpdir
             policy, rl_trainer = await asyncio.gather(
                 *[
-                    Policy.options(**services_policy_cfg).as_service(**cfg.policy),
+                    Generator.options(**services_policy_cfg).as_service(**cfg.policy),
                     MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg),
                 ]
             )
@@ -224,7 +224,7 @@ async def test_sanity_check(self, request):
             for _, e in errs.items():
                 assert not e, f"Validation failed with exception: {e}"
 
-        await policy.update_weights.fanout(policy_version=v1)
+        await policy.update_weights.fanout(version=v1)
         all_errs = await policy._test_validate_model_params.fanout(
             validate_fn_all_zeros
         )
@@ -233,7 +233,7 @@ async def test_sanity_check(self, request):
                 assert not e, f"Validation failed with exception: {e}"
 
         # Reloading v0, getting back original weights
-        await policy.update_weights.fanout(policy_version=v0)
+        await policy.update_weights.fanout(version=v0)
         all_errs = await policy._test_validate_model_params.fanout(validate_fn)
         for errs in all_errs:
             for _, e in errs.items():
diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py
index e862ac60d..4f77cd939 100644
--- a/tests/sandbox/toy_rl/sumdigits.py
+++ b/tests/sandbox/toy_rl/sumdigits.py
@@ -16,7 +16,7 @@
 import torch.nn.functional as F
 import torchstore as ts
 from forge.actors._torchstore_utils import get_param_key
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
@@ -445,7 +445,7 @@ async def main(cfg: DictConfig):
         ref_model,
     ) = await asyncio.gather(
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
-        Policy.options(**cfg.services.policy).as_service(**cfg.policy),
+        Generator.options(**cfg.services.policy).as_service(**cfg.policy),
         Trainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(**cfg.replay_buffer),
         RewardActor.options(**cfg.services.reward_actor).as_service(),
diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py
index 54b093841..5473b5fc4 100644
--- a/tests/sandbox/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -13,7 +13,7 @@
 
 import os
 
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
 from forge.cli.config import parse
 from forge.controller.provisioner import init_provisioner, shutdown
 
@@ -40,7 +40,7 @@ async def run(cfg: DictConfig):
     prompt = "Tell me a joke"
 
     print("Spawning service...")
-    policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
+    policy = await Generator.options(**cfg.services.policy).as_service(**cfg.policy)
 
     import time
diff --git a/tests/unit_tests/test_generator_config.py b/tests/unit_tests/test_generator_config.py
new file mode 100644
index 000000000..1c64e42e2
--- /dev/null
+++ b/tests/unit_tests/test_generator_config.py
@@ -0,0 +1,145 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import pytest
+import yaml
+
+
+def _import_error():
+    """Check if there are import errors that would cause CI failures."""
+    try:
+        import forge.actors.generator  # noqa: F401
+
+        return False
+    except ImportError:
+        return True
+
+
+class TestGeneratorConfig(unittest.TestCase):
+    """Test suite for Generator configuration handling after PolicyConfig removal."""
+
+    @pytest.mark.skipif(
+        _import_error(),
+        reason="Import error, likely due to missing dependencies on CI.",
+    )
+    def test_generator_default_initialization(self):
+        """Generator initializes with default values."""
+        from forge.actors.generator import Generator
+        from vllm.engine.arg_utils import EngineArgs
+        from vllm.sampling_params import SamplingParams
+
+        generator = Generator()
+
+        # Default factories
+        self.assertIsInstance(generator.engine_args, EngineArgs)
+        self.assertIsInstance(generator.sampling_params, SamplingParams)
+        self.assertIsNone(generator.available_devices)
+
+        # Worker defaults
+        self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
+        self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
+        self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
+        self.assertFalse(generator.engine_args.enforce_eager)
+        self.assertTrue(generator.engine_args._is_v1_supported_oracle())
+
+        # Sampling defaults
+        self.assertEqual(generator.sampling_params.n, 1)
+        self.assertFalse(generator.sampling_params.guided_decoding)
+        self.assertEqual(generator.sampling_params.max_tokens, 16)
+
+    @pytest.mark.skipif(
+        _import_error(),
+        reason="Import error, likely due to missing dependencies on CI.",
+    )
+    def test_generator_with_dict_configs(self):
+        """Generator accepts dicts for engine_args and sampling_params, including nested dicts."""
+        from forge.actors.generator import Generator
+        from vllm.engine.arg_utils import EngineArgs
+        from vllm.sampling_params import SamplingParams
+
+        # Test with nested dict structure
+        engine_dict = {
+            "model": "test-model-6789",
+            "tensor_parallel_size": 7777,
+            "pipeline_parallel_size": 8888,
+            "enforce_eager": True,
+            "gpu_memory_utilization": 0.9,
+            "max_model_len": 4096,
+        }
+
+        sampling_dict = {
+            "n": 1357,
+            "max_tokens": 2468,
+        }
+
+        generator = Generator(
+            engine_args=engine_dict,
+            sampling_params=sampling_dict,
+            available_devices="test-gpu-device-abcd",
+        )
+
+        self.assertIsInstance(generator.engine_args, EngineArgs)
+        self.assertIsInstance(generator.sampling_params, SamplingParams)
+
+        # Test basic fields
+        self.assertEqual(generator.engine_args.model, "test-model-6789")
+        self.assertEqual(generator.engine_args.tensor_parallel_size, 7777)
+        self.assertEqual(generator.engine_args.pipeline_parallel_size, 8888)
+        self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.9)
+        self.assertEqual(generator.engine_args.max_model_len, 4096)
+        self.assertTrue(generator.engine_args.enforce_eager)
+        self.assertTrue(generator.engine_args._is_v1_supported_oracle())
+
+        self.assertEqual(generator.sampling_params.n, 1357)
+        self.assertEqual(generator.sampling_params.max_tokens, 2468)
+
+    @pytest.mark.skipif(
+        _import_error(),
+        reason="Import error, likely due to missing dependencies on CI.",
+    )
+    def test_generator_yaml_config_loading(self):
+        """Generator can be constructed from a YAML config file."""
+        from forge.actors.generator import Generator
+
+        yaml_content = """
+        engine_args:
+            model: "yaml-test-model-9876"
+            tensor_parallel_size: 1234
+            pipeline_parallel_size: 5678
+            enforce_eager: true
+
+        sampling_params:
+            n: 2468
+            max_tokens: 1357
+
+        available_devices: "yaml-test-device-xyz"
+        """
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+            f.write(yaml_content)
+            f.flush()
+
+            with open(f.name, "r") as yaml_file:
+                config = yaml.safe_load(yaml_file)
+
+            generator = Generator(**config)
+            self.assertEqual(generator.engine_args.model, "yaml-test-model-9876")
+            self.assertEqual(generator.engine_args.tensor_parallel_size, 1234)
+            self.assertEqual(generator.engine_args.pipeline_parallel_size, 5678)
+            self.assertTrue(generator.engine_args.enforce_eager)
+            self.assertTrue(generator.engine_args._is_v1_supported_oracle())
+
+            self.assertEqual(generator.sampling_params.n, 2468)
+            self.assertEqual(generator.sampling_params.max_tokens, 1357)
+
+            self.assertEqual(generator.available_devices, "yaml-test-device-xyz")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py
deleted file mode 100644
index 31288e0bb..000000000
--- a/tests/unit_tests/test_policy_config.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import tempfile
-import unittest
-
-import pytest
-import yaml
-
-
-def _import_error():
-    """Check if there are import errors that would cause CI failures."""
-    try:
-        import forge.actors.policy  # noqa: F401
-
-        return False
-    except ImportError:
-        return True
-
-
-class TestPolicyConfig(unittest.TestCase):
-    """Test suite for Policy configuration handling after PolicyConfig removal."""
-
-    @pytest.mark.skipif(
-        _import_error(),
-        reason="Import error, likely due to missing dependencies on CI.",
-    )
-    def test_policy_default_initialization(self):
-        """Policy initializes with default values."""
-        from forge.actors.policy import Policy
-        from vllm.engine.arg_utils import EngineArgs
-        from vllm.sampling_params import SamplingParams
-
-        policy = Policy()
-
-        # Default factories
-        self.assertIsInstance(policy.engine_args, EngineArgs)
-        self.assertIsInstance(policy.sampling_params, SamplingParams)
-        self.assertIsNone(policy.available_devices)
-
-        # Worker defaults
-        self.assertEqual(policy.engine_args.model, "Qwen/Qwen3-0.6B")
-        self.assertEqual(policy.engine_args.tensor_parallel_size, 1)
-        self.assertEqual(policy.engine_args.pipeline_parallel_size, 1)
-        self.assertFalse(policy.engine_args.enforce_eager)
-        self.assertTrue(policy.engine_args._is_v1_supported_oracle())
-
-        # Sampling defaults
-        self.assertEqual(policy.sampling_params.n, 1)
-        self.assertFalse(policy.sampling_params.guided_decoding)
-        self.assertEqual(policy.sampling_params.max_tokens, 16)
-
-    @pytest.mark.skipif(
-        _import_error(),
-        reason="Import error, likely due to missing dependencies on CI.",
-    )
-    def test_policy_with_dict_configs(self):
-        """Policy accepts dicts for engine_args and sampling_params, including nested dicts."""
-        from forge.actors.policy import Policy
-        from vllm.engine.arg_utils import EngineArgs
-        from vllm.sampling_params import SamplingParams
-
-        # Test with nested dict structure
-        engine_dict = {
-            "model": "test-model-6789",
-            "tensor_parallel_size": 7777,
-            "pipeline_parallel_size": 8888,
-            "enforce_eager": True,
-            "gpu_memory_utilization": 0.9,
-            "max_model_len": 4096,
-        }
-
-        sampling_dict = {
-            "n": 1357,
-            "max_tokens": 2468,
-        }
-
-        policy = Policy(
-            engine_args=engine_dict,
-            sampling_params=sampling_dict,
-            available_devices="test-gpu-device-abcd",
-        )
-
-        self.assertIsInstance(policy.engine_args, EngineArgs)
-        self.assertIsInstance(policy.sampling_params, SamplingParams)
-
-        # Test basic fields
-        self.assertEqual(policy.engine_args.model, "test-model-6789")
-        self.assertEqual(policy.engine_args.tensor_parallel_size, 7777)
-        self.assertEqual(policy.engine_args.pipeline_parallel_size, 8888)
-        self.assertEqual(policy.engine_args.gpu_memory_utilization, 0.9)
-        self.assertEqual(policy.engine_args.max_model_len, 4096)
-        self.assertTrue(policy.engine_args.enforce_eager)
-        self.assertTrue(policy.engine_args._is_v1_supported_oracle())
-
-        self.assertEqual(policy.sampling_params.n, 1357)
-        self.assertEqual(policy.sampling_params.max_tokens, 2468)
-
-    @pytest.mark.skipif(
-        _import_error(),
-        reason="Import error, likely due to missing dependencies on CI.",
-    )
-    def test_policy_yaml_config_loading(self):
-        """Policy can be constructed from a YAML config file."""
-        from forge.actors.policy import Policy
-
-        yaml_content = """
-        engine_args:
-            model: "yaml-test-model-9876"
-            tensor_parallel_size: 1234
-            pipeline_parallel_size: 5678
-            enforce_eager: true
-
-        sampling_params:
-            n: 2468
-            max_tokens: 1357
-
-        available_devices: "yaml-test-device-xyz"
-        """
-
-        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write(yaml_content)
-            f.flush()
-
-            with open(f.name, "r") as yaml_file:
-                config = yaml.safe_load(yaml_file)
-
-            policy = Policy(**config)
-
-        self.assertEqual(policy.engine_args.model, "yaml-test-model-9876")
-        self.assertEqual(policy.engine_args.tensor_parallel_size, 1234)
-        self.assertEqual(policy.engine_args.pipeline_parallel_size, 5678)
-        self.assertTrue(policy.engine_args.enforce_eager)
-        self.assertTrue(policy.engine_args._is_v1_supported_oracle())
-
-        self.assertEqual(policy.sampling_params.n, 2468)
-        self.assertEqual(policy.sampling_params.max_tokens, 1357)
-
-        self.assertEqual(policy.available_devices, "yaml-test-device-xyz")
-
-
-if __name__ == "__main__":
-    unittest.main()
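
Migration note for downstream code. The sketch below is illustrative and not part of the patch; it assumes only the surface shown in this diff (the `forge.actors.generator` module, `Generator.options(...).as_service(...)`, `generate`, the keyword-only `update_weights(version=...)`, and the `Policy = Generator` alias kept in apps/grpo/main.py). The `main` wrapper and `asyncio.run` call are placeholders for whatever entry point a caller already has.

    # Hypothetical downstream migration example -- names outside this diff are
    # placeholders, not prescribed by the change.
    import asyncio

    # Before: from forge.actors.policy import Policy
    from forge.actors.generator import Generator

    # Optional local alias so call sites that still spell "Policy" keep working,
    # mirroring the alias added in apps/grpo/main.py:
    Policy = Generator

    async def main() -> None:
        # Single-replica generator service; engine_args and sampling_params fall
        # back to the defaults exercised by test_generator_config.py.
        generator = await Generator.options(
            procs=1, num_replicas=1, with_gpus=True
        ).as_service()

        completions = await generator.generate("Tell me a joke")
        print(completions[0].text)

        # Note the renamed keyword: the weight-update endpoint now takes
        # `version`, not `policy_version`, e.g.
        #     await generator.update_weights.fanout(version=1)

        await generator.shutdown()

    if __name__ == "__main__":
        asyncio.run(main())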