Skip to content

Commit ffacf70

Browse files
joecummings authored and
allenwang28 committed
Policy cleaner launch / setup (#401)
1 parent d2130ae commit ffacf70

File tree

2 files changed

+45
-53
lines changed

2 files changed

+45
-53
lines changed

src/forge/actors/generator.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,27 @@
1616

1717
import torch
1818
import torchstore as ts
19+
20+
from forge.actors._torchstore_utils import (
21+
extract_param_name,
22+
get_dcp_whole_state_dict_key,
23+
get_param_key,
24+
get_param_prefix,
25+
load_tensor_from_dcp,
26+
)
27+
28+
from forge.controller import (
29+
ForgeActor,
30+
get_proc_mesh,
31+
host_mesh_from_proc,
32+
stop_proc_mesh,
33+
)
34+
from forge.data_models.completion import Completion
35+
from forge.data_models.prompt import to_prompt
36+
from forge.env import TORCHSTORE_USE_RDMA
37+
from forge.observability.metrics import record_metric, Reduce
38+
from forge.observability.perf_tracker import Tracer
39+
from forge.types import ProcessConfig
1940
from monarch.actor import current_rank, endpoint, ProcMesh
2041
from vllm.config import VllmConfig
2142

@@ -40,27 +61,6 @@
4061
from vllm.v1.structured_output import StructuredOutputManager
4162
from vllm.worker.worker_base import WorkerWrapperBase
4263

43-
from forge.actors._torchstore_utils import (
44-
extract_param_name,
45-
get_dcp_whole_state_dict_key,
46-
get_param_key,
47-
get_param_prefix,
48-
load_tensor_from_dcp,
49-
)
50-
51-
from forge.controller import (
52-
ForgeActor,
53-
get_proc_mesh,
54-
host_mesh_from_proc,
55-
stop_proc_mesh,
56-
)
57-
from forge.data_models.completion import Completion
58-
from forge.data_models.prompt import to_prompt
59-
from forge.env import TORCHSTORE_USE_RDMA
60-
from forge.observability.metrics import record_metric, Reduce
61-
from forge.observability.perf_tracker import Tracer
62-
from forge.types import ProcessConfig
63-
6464
logger = logging.getLogger(__name__)
6565
logger.setLevel(logging.INFO)
6666

tests/unit_tests/test_generator_config.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def test_generator_default_initialization(self):
3939
# Default factories
4040
self.assertIsInstance(generator.engine_args, EngineArgs)
4141
self.assertIsInstance(generator.sampling_params, SamplingParams)
42-
self.assertIsNone(generator.available_devices)
4342

4443
# Worker defaults
4544
self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
@@ -58,46 +57,43 @@ def test_generator_default_initialization(self):
5857
reason="Import error, likely due to missing dependencies on CI.",
5958
)
6059
def test_generator_with_dict_configs(self):
61-
"""Generator accepts dicts for engine_config and sampling_config, including nested dicts."""
6260
from forge.actors.generator import Generator
6361
from vllm.engine.arg_utils import EngineArgs
6462
from vllm.sampling_params import SamplingParams
6563

66-
# Test with nested dict structure
6764
engine_dict = {
68-
"model": "test-model-6789",
69-
"tensor_parallel_size": 7777,
70-
"pipeline_parallel_size": 8888,
65+
"model": "Qwen/Qwen3-0.6B",
66+
"tensor_parallel_size": 1,
67+
"pipeline_parallel_size": 1,
7168
"enforce_eager": True,
72-
"gpu_memory_utilization": 0.9,
73-
"max_model_len": 4096,
69+
"gpu_memory_utilization": 0.1,
70+
"max_model_len": 1024,
7471
}
7572

7673
sampling_dict = {
77-
"n": 1357,
78-
"max_tokens": 2468,
74+
"n": 2,
75+
"max_tokens": 32,
7976
}
8077

8178
generator = Generator(
8279
engine_args=engine_dict,
8380
sampling_params=sampling_dict,
84-
available_devices="test-gpu-device-abcd",
8581
)
8682

8783
self.assertIsInstance(generator.engine_args, EngineArgs)
8884
self.assertIsInstance(generator.sampling_params, SamplingParams)
8985

9086
# Test basic fields
91-
self.assertEqual(generator.engine_args.model, "test-model-6789")
92-
self.assertEqual(generator.engine_args.tensor_parallel_size, 7777)
93-
self.assertEqual(generator.engine_args.pipeline_parallel_size, 8888)
94-
self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.9)
95-
self.assertEqual(generator.engine_args.max_model_len, 4096)
87+
self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
88+
self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
89+
self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
90+
self.assertEqual(generator.engine_args.gpu_memory_utilization, 0.1)
91+
self.assertEqual(generator.engine_args.max_model_len, 1024)
9692
self.assertTrue(generator.engine_args.enforce_eager)
9793
self.assertTrue(generator.engine_args._is_v1_supported_oracle())
9894

99-
self.assertEqual(generator.sampling_params.n, 1357)
100-
self.assertEqual(generator.sampling_params.max_tokens, 2468)
95+
self.assertEqual(generator.sampling_params.n, 2)
96+
self.assertEqual(generator.sampling_params.max_tokens, 32)
10197

10298
@pytest.mark.skipif(
10399
_import_error(),
@@ -109,16 +105,14 @@ def test_generator_yaml_config_loading(self):
109105

110106
yaml_content = """
111107
engine_args:
112-
model: "yaml-test-model-9876"
113-
tensor_parallel_size: 1234
114-
pipeline_parallel_size: 5678
108+
model: "Qwen/Qwen3-0.6B"
109+
tensor_parallel_size: 1
110+
pipeline_parallel_size: 1
115111
enforce_eager: true
116112
117113
sampling_params:
118-
n: 2468
119-
max_tokens: 1357
120-
121-
available_devices: "yaml-test-device-xyz"
114+
n: 2
115+
max_tokens: 32
122116
"""
123117

124118
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
@@ -129,16 +123,14 @@ def test_generator_yaml_config_loading(self):
129123
config = yaml.safe_load(yaml_file)
130124

131125
generator = Generator(**config)
132-
self.assertEqual(generator.engine_args.model, "yaml-test-model-9876")
133-
self.assertEqual(generator.engine_args.tensor_parallel_size, 1234)
134-
self.assertEqual(generator.engine_args.pipeline_parallel_size, 5678)
126+
self.assertEqual(generator.engine_args.model, "Qwen/Qwen3-0.6B")
127+
self.assertEqual(generator.engine_args.tensor_parallel_size, 1)
128+
self.assertEqual(generator.engine_args.pipeline_parallel_size, 1)
135129
self.assertTrue(generator.engine_args.enforce_eager)
136130
self.assertTrue(generator.engine_args._is_v1_supported_oracle())
137131

138-
self.assertEqual(generator.sampling_params.n, 2468)
139-
self.assertEqual(generator.sampling_params.max_tokens, 1357)
140-
141-
self.assertEqual(generator.available_devices, "yaml-test-device-xyz")
132+
self.assertEqual(generator.sampling_params.n, 2)
133+
self.assertEqual(generator.sampling_params.max_tokens, 32)
142134

143135

144136
if __name__ == "__main__":

0 commit comments

Comments
 (0)