|
1 | | -import tempfile |
2 | | -import unittest |
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
3 | 6 |
|
4 | | -import yaml |
5 | | - |
6 | | -from forge.actors.policy import Policy, SamplingOverrides, WorkerConfig |
7 | | - |
8 | | - |
9 | | -class TestPolicyConfig(unittest.TestCase): |
10 | | - """Test suite for Policy configuration handling after PolicyConfig removal.""" |
11 | | - |
12 | | - def test_policy_default_initialization(self): |
13 | | - """Policy initializes with default values.""" |
14 | | - policy = Policy() |
15 | | - |
16 | | - # Default factories |
17 | | - self.assertIsInstance(policy.worker_params, WorkerConfig) |
18 | | - self.assertIsInstance(policy.sampling_overrides, SamplingOverrides) |
19 | | - self.assertIsNone(policy.available_devices) |
20 | | - |
21 | | - # Worker defaults |
22 | | - self.assertEqual(policy.worker_params.model, "meta-llama/Llama-3.1-8B-Instruct") |
23 | | - self.assertEqual(policy.worker_params.tensor_parallel_size, 1) |
24 | | - self.assertEqual(policy.worker_params.pipeline_parallel_size, 1) |
25 | | - self.assertFalse(policy.worker_params.enforce_eager) |
26 | | - |
27 | | - # Sampling defaults |
28 | | - self.assertEqual(policy.sampling_overrides.num_samples, 1) |
29 | | - self.assertFalse(policy.sampling_overrides.guided_decoding) |
30 | | - self.assertEqual(policy.sampling_overrides.max_tokens, 512) |
31 | | - |
32 | | - def test_policy_with_dict_configs(self): |
33 | | - """Policy accepts dicts for worker_params and sampling_overrides.""" |
34 | | - worker_dict = { |
35 | | - "model": "test-model-6789", |
36 | | - "tensor_parallel_size": 7777, |
37 | | - "pipeline_parallel_size": 8888, |
38 | | - "enforce_eager": True, |
39 | | - } |
40 | | - |
41 | | - sampling_dict = { |
42 | | - "num_samples": 1357, |
43 | | - "guided_decoding": True, |
44 | | - "max_tokens": 2468, |
45 | | - } |
46 | | - |
47 | | - policy = Policy( |
48 | | - worker_params=worker_dict, |
49 | | - sampling_overrides=sampling_dict, |
50 | | - available_devices="test-gpu-device-abcd", |
51 | | - ) |
52 | | - |
53 | | - self.assertIsInstance(policy.worker_params, WorkerConfig) |
54 | | - self.assertIsInstance(policy.sampling_overrides, SamplingOverrides) |
55 | | - |
56 | | - self.assertEqual(policy.worker_params.model, "test-model-6789") |
57 | | - self.assertEqual(policy.worker_params.tensor_parallel_size, 7777) |
58 | | - self.assertEqual(policy.worker_params.pipeline_parallel_size, 8888) |
59 | | - self.assertTrue(policy.worker_params.enforce_eager) |
60 | | - |
61 | | - self.assertEqual(policy.sampling_overrides.num_samples, 1357) |
62 | | - self.assertTrue(policy.sampling_overrides.guided_decoding) |
63 | | - self.assertEqual(policy.sampling_overrides.max_tokens, 2468) |
64 | | - |
65 | | - def test_policy_yaml_config_loading(self): |
66 | | - """Policy can be constructed from a YAML config file.""" |
67 | | - yaml_content = """ |
68 | | - worker_params: |
69 | | - model: "yaml-test-model-9876" |
70 | | - tensor_parallel_size: 1234 |
71 | | - pipeline_parallel_size: 5678 |
72 | | - enforce_eager: true |
73 | | -
|
74 | | - sampling_overrides: |
75 | | - num_samples: 2468 |
76 | | - guided_decoding: true |
77 | | - max_tokens: 1357 |
78 | | -
|
79 | | - available_devices: "yaml-test-device-xyz" |
80 | | - """ |
81 | | - |
82 | | - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: |
83 | | - f.write(yaml_content) |
84 | | - f.flush() |
85 | | - |
86 | | - with open(f.name, "r") as yaml_file: |
87 | | - config = yaml.safe_load(yaml_file) |
88 | | - |
89 | | - policy = Policy(**config) |
90 | | - |
91 | | - self.assertEqual(policy.worker_params.model, "yaml-test-model-9876") |
92 | | - self.assertEqual(policy.worker_params.tensor_parallel_size, 1234) |
93 | | - self.assertEqual(policy.worker_params.pipeline_parallel_size, 5678) |
94 | | - self.assertTrue(policy.worker_params.enforce_eager) |
95 | | - |
96 | | - self.assertEqual(policy.sampling_overrides.num_samples, 2468) |
97 | | - self.assertTrue(policy.sampling_overrides.guided_decoding) |
98 | | - self.assertEqual(policy.sampling_overrides.max_tokens, 1357) |
99 | | - |
100 | | - self.assertEqual(policy.available_devices, "yaml-test-device-xyz") |
101 | | - |
102 | | - def test_workerconfig_ignores_invalid_keys(self): |
103 | | - """WorkerConfig.from_dict ignores unexpected keys.""" |
104 | | - worker_dict = { |
105 | | - "model": "custom-model", |
106 | | - "tensor_parallel_size": 2, |
107 | | - "invalid_key_123": "should be ignored", |
108 | | - } |
109 | | - |
110 | | - config = WorkerConfig.from_dict(worker_dict) |
111 | | - |
112 | | - self.assertEqual(config.model, "custom-model") |
113 | | - self.assertEqual(config.tensor_parallel_size, 2) |
114 | | - self.assertFalse(hasattr(config, "invalid_key_123")) |
115 | | - |
116 | | - |
117 | | -if __name__ == "__main__": |
118 | | - unittest.main() |
0 commit comments