Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __post_init__(self):

if isinstance(self.engine_args, Mapping):
self.engine_args = EngineArgs(**self.engine_args)
self.engine_args._is_v1_supported_oracle = lambda *_: True
self.engine_args._is_v1_supported_oracle = lambda *_: True

if isinstance(self.sampling_params, Mapping):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
Expand Down
23 changes: 6 additions & 17 deletions tests/unit_tests/test_policy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_policy_default_initialization(self):
self.assertIsNone(policy.available_devices)

# Worker defaults
self.assertEqual(policy.engine_args.model, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(policy.engine_args.model, "Qwen/Qwen3-0.6B")
self.assertEqual(policy.engine_args.tensor_parallel_size, 1)
self.assertEqual(policy.engine_args.pipeline_parallel_size, 1)
self.assertFalse(policy.engine_args.enforce_eager)
Expand All @@ -51,7 +51,7 @@ def test_policy_default_initialization(self):
# Sampling defaults
self.assertEqual(policy.sampling_params.n, 1)
self.assertFalse(policy.sampling_params.guided_decoding)
self.assertEqual(policy.sampling_params.max_tokens, 512)
self.assertEqual(policy.sampling_params.max_tokens, 16)

@pytest.mark.skipif(
_import_error(),
Expand All @@ -69,11 +69,8 @@ def test_policy_with_dict_configs(self):
"tensor_parallel_size": 7777,
"pipeline_parallel_size": 8888,
"enforce_eager": True,
"nested_config": {
"gpu_memory_utilization": 0.9,
"max_model_len": 4096,
"custom_settings": {"temperature": 0.7, "top_p": 0.9},
},
"gpu_memory_utilization": 0.9,
"max_model_len": 4096,
}

sampling_dict = {
Expand All @@ -94,22 +91,14 @@ def test_policy_with_dict_configs(self):
self.assertEqual(policy.engine_args.model, "test-model-6789")
self.assertEqual(policy.engine_args.tensor_parallel_size, 7777)
self.assertEqual(policy.engine_args.pipeline_parallel_size, 8888)
self.assertEqual(policy.engine_args.gpu_memory_utilization, 0.9)
self.assertEqual(policy.engine_args.max_model_len, 4096)
self.assertTrue(policy.engine_args.enforce_eager)
self.assertTrue(policy.engine_args._is_v1_supported_oracle())

self.assertEqual(policy.sampling_params.n, 1357)
self.assertEqual(policy.sampling_params.max_tokens, 2468)

# Test that engine_dict accepts and preserves nested dict structure
# The original engine_dict should remain unchanged and accessible
self.assertIn("nested_config", engine_dict)
self.assertEqual(engine_dict["nested_config"]["gpu_memory_utilization"], 0.9)
self.assertEqual(engine_dict["nested_config"]["max_model_len"], 4096)
self.assertEqual(
engine_dict["nested_config"]["custom_settings"]["temperature"], 0.7
)
self.assertEqual(engine_dict["nested_config"]["custom_settings"]["top_p"], 0.9)

@pytest.mark.skipif(
_import_error(),
reason="Import error, likely due to missing dependencies on CI.",
Expand Down
Loading