Skip to content

Commit 063afe6

Browse files
committed
solve conflict
1 parent ba74b43 commit 063afe6

File tree

1 file changed

+5
-117
lines changed

1 file changed

+5
-117
lines changed
Lines changed: 5 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,6 @@
1-
import tempfile
2-
import unittest
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
36

4-
import yaml
5-
6-
from forge.actors.policy import Policy, SamplingOverrides, WorkerConfig
7-
8-
9-
class TestPolicyConfig(unittest.TestCase):
10-
"""Test suite for Policy configuration handling after PolicyConfig removal."""
11-
12-
def test_policy_default_initialization(self):
13-
"""Policy initializes with default values."""
14-
policy = Policy()
15-
16-
# Default factories
17-
self.assertIsInstance(policy.worker_params, WorkerConfig)
18-
self.assertIsInstance(policy.sampling_overrides, SamplingOverrides)
19-
self.assertIsNone(policy.available_devices)
20-
21-
# Worker defaults
22-
self.assertEqual(policy.worker_params.model, "meta-llama/Llama-3.1-8B-Instruct")
23-
self.assertEqual(policy.worker_params.tensor_parallel_size, 1)
24-
self.assertEqual(policy.worker_params.pipeline_parallel_size, 1)
25-
self.assertFalse(policy.worker_params.enforce_eager)
26-
27-
# Sampling defaults
28-
self.assertEqual(policy.sampling_overrides.num_samples, 1)
29-
self.assertFalse(policy.sampling_overrides.guided_decoding)
30-
self.assertEqual(policy.sampling_overrides.max_tokens, 512)
31-
32-
def test_policy_with_dict_configs(self):
33-
"""Policy accepts dicts for worker_params and sampling_overrides."""
34-
worker_dict = {
35-
"model": "test-model-6789",
36-
"tensor_parallel_size": 7777,
37-
"pipeline_parallel_size": 8888,
38-
"enforce_eager": True,
39-
}
40-
41-
sampling_dict = {
42-
"num_samples": 1357,
43-
"guided_decoding": True,
44-
"max_tokens": 2468,
45-
}
46-
47-
policy = Policy(
48-
worker_params=worker_dict,
49-
sampling_overrides=sampling_dict,
50-
available_devices="test-gpu-device-abcd",
51-
)
52-
53-
self.assertIsInstance(policy.worker_params, WorkerConfig)
54-
self.assertIsInstance(policy.sampling_overrides, SamplingOverrides)
55-
56-
self.assertEqual(policy.worker_params.model, "test-model-6789")
57-
self.assertEqual(policy.worker_params.tensor_parallel_size, 7777)
58-
self.assertEqual(policy.worker_params.pipeline_parallel_size, 8888)
59-
self.assertTrue(policy.worker_params.enforce_eager)
60-
61-
self.assertEqual(policy.sampling_overrides.num_samples, 1357)
62-
self.assertTrue(policy.sampling_overrides.guided_decoding)
63-
self.assertEqual(policy.sampling_overrides.max_tokens, 2468)
64-
65-
def test_policy_yaml_config_loading(self):
66-
"""Policy can be constructed from a YAML config file."""
67-
yaml_content = """
68-
worker_params:
69-
model: "yaml-test-model-9876"
70-
tensor_parallel_size: 1234
71-
pipeline_parallel_size: 5678
72-
enforce_eager: true
73-
74-
sampling_overrides:
75-
num_samples: 2468
76-
guided_decoding: true
77-
max_tokens: 1357
78-
79-
available_devices: "yaml-test-device-xyz"
80-
"""
81-
82-
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
83-
f.write(yaml_content)
84-
f.flush()
85-
86-
with open(f.name, "r") as yaml_file:
87-
config = yaml.safe_load(yaml_file)
88-
89-
policy = Policy(**config)
90-
91-
self.assertEqual(policy.worker_params.model, "yaml-test-model-9876")
92-
self.assertEqual(policy.worker_params.tensor_parallel_size, 1234)
93-
self.assertEqual(policy.worker_params.pipeline_parallel_size, 5678)
94-
self.assertTrue(policy.worker_params.enforce_eager)
95-
96-
self.assertEqual(policy.sampling_overrides.num_samples, 2468)
97-
self.assertTrue(policy.sampling_overrides.guided_decoding)
98-
self.assertEqual(policy.sampling_overrides.max_tokens, 1357)
99-
100-
self.assertEqual(policy.available_devices, "yaml-test-device-xyz")
101-
102-
def test_workerconfig_ignores_invalid_keys(self):
103-
"""WorkerConfig.from_dict ignores unexpected keys."""
104-
worker_dict = {
105-
"model": "custom-model",
106-
"tensor_parallel_size": 2,
107-
"invalid_key_123": "should be ignored",
108-
}
109-
110-
config = WorkerConfig.from_dict(worker_dict)
111-
112-
self.assertEqual(config.model, "custom-model")
113-
self.assertEqual(config.tensor_parallel_size, 2)
114-
self.assertFalse(hasattr(config, "invalid_key_123"))
115-
116-
117-
if __name__ == "__main__":
118-
unittest.main()

0 commit comments

Comments
 (0)