diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 686ec973b..dbffaf740 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -83,7 +83,7 @@ def __post_init__(self): if isinstance(self.engine_args, Mapping): self.engine_args = EngineArgs(**self.engine_args) - self.engine_args._is_v1_supported_oracle = lambda *_: True + self.engine_args._is_v1_supported_oracle = lambda *_: True if isinstance(self.sampling_params, Mapping): self.sampling_params = SamplingParams.from_optional(**self.sampling_params) diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py index c1a68e540..31288e0bb 100644 --- a/tests/unit_tests/test_policy_config.py +++ b/tests/unit_tests/test_policy_config.py @@ -42,7 +42,7 @@ def test_policy_default_initialization(self): self.assertIsNone(policy.available_devices) # Worker defaults - self.assertEqual(policy.engine_args.model, "meta-llama/Llama-3.1-8B-Instruct") + self.assertEqual(policy.engine_args.model, "Qwen/Qwen3-0.6B") self.assertEqual(policy.engine_args.tensor_parallel_size, 1) self.assertEqual(policy.engine_args.pipeline_parallel_size, 1) self.assertFalse(policy.engine_args.enforce_eager) @@ -51,7 +51,7 @@ def test_policy_default_initialization(self): # Sampling defaults self.assertEqual(policy.sampling_params.n, 1) self.assertFalse(policy.sampling_params.guided_decoding) - self.assertEqual(policy.sampling_params.max_tokens, 512) + self.assertEqual(policy.sampling_params.max_tokens, 16) @pytest.mark.skipif( _import_error(), @@ -69,11 +69,8 @@ def test_policy_with_dict_configs(self): "tensor_parallel_size": 7777, "pipeline_parallel_size": 8888, "enforce_eager": True, - "nested_config": { - "gpu_memory_utilization": 0.9, - "max_model_len": 4096, - "custom_settings": {"temperature": 0.7, "top_p": 0.9}, - }, + "gpu_memory_utilization": 0.9, + "max_model_len": 4096, } sampling_dict = { @@ -94,22 +91,14 @@ def test_policy_with_dict_configs(self): self.assertEqual(policy.engine_args.model, "test-model-6789") self.assertEqual(policy.engine_args.tensor_parallel_size, 7777) self.assertEqual(policy.engine_args.pipeline_parallel_size, 8888) + self.assertEqual(policy.engine_args.gpu_memory_utilization, 0.9) + self.assertEqual(policy.engine_args.max_model_len, 4096) self.assertTrue(policy.engine_args.enforce_eager) self.assertTrue(policy.engine_args._is_v1_supported_oracle()) self.assertEqual(policy.sampling_params.n, 1357) self.assertEqual(policy.sampling_params.max_tokens, 2468) - # Test that engine_dict accepts and preserves nested dict structure - # The original engine_dict should remain unchanged and accessible - self.assertIn("nested_config", engine_dict) - self.assertEqual(engine_dict["nested_config"]["gpu_memory_utilization"], 0.9) - self.assertEqual(engine_dict["nested_config"]["max_model_len"], 4096) - self.assertEqual( - engine_dict["nested_config"]["custom_settings"]["temperature"], 0.7 - ) - self.assertEqual(engine_dict["nested_config"]["custom_settings"]["top_p"], 0.9) - @pytest.mark.skipif( _import_error(), reason="Import error, likely due to missing dependencies on CI.",