Commit 32ac7f3: refactor
1 parent 5f7cf3c

File tree: 2 files changed (+29, -40 lines)

src/forge/actors/policy.py (2 additions, 4 deletions)
```diff
@@ -107,7 +107,6 @@ class Policy(PolicyInterface):
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
-    store = None
 
     def __post_init__(self):
         self._run_task: asyncio.Task | None = None
@@ -121,7 +120,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
         *,
         process_config: ProcessConfig,
         config: PolicyConfig,
-        store=None,
         **kwargs,
     ) -> "Policy":
         # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
@@ -172,7 +170,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
     async def setup(self):
         # Set up policy_worker
         assert self.policy_worker is not None, "Policy worker should not be None"
-        await self.policy_worker.setup.call(store=self.store)
+        await self.policy_worker.setup.call()
 
         self.request_id = 0
         self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
@@ -395,7 +393,7 @@ def __post_init__(self):
         self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
 
     @endpoint
-    async def setup(self, store=None):
+    async def setup(self):
         # TODO: remove ["gpus"] when monarch implements a flat rank
         self.rank = current_rank()["gpus"]
         self.worker = self.setup_worker()
```
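The shape of this change: the `store` handle is no longer threaded through `launch()` and `setup()` as an explicit argument; instead the process initializes torchstore once (the test calls `await ts.initialize()`) and workers resolve it implicitly. Below is a self-contained toy sketch of that pattern, not the real torchstore or forge API; `_STORE`, `initialize`, and this `PolicyWorker` are stand-ins for illustration only.

```python
import asyncio

_STORE: dict | None = None  # stand-in for the process-wide torchstore singleton


async def initialize() -> None:
    # Mirrors `await ts.initialize()`: set up the store once per process.
    global _STORE
    _STORE = {}


class PolicyWorker:
    async def setup(self) -> None:
        # Post-refactor: setup() takes no `store` argument and reads the
        # process-wide singleton instead of a handle passed by the caller.
        assert _STORE is not None, "initialize() must run before setup()"
        self.store = _STORE


async def main() -> None:
    await initialize()
    worker = PolicyWorker()
    await worker.setup()        # no store= kwarg, as in the new signature
    worker.store["w"] = [0.0]   # reads/writes go through the singleton


asyncio.run(main())
```

The trade-off is less plumbing at every spawn site in exchange for an implicit dependency on process-level initialization order.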

tests/integration_tests/test_policy_update.py (27 additions, 36 deletions)
```diff
@@ -186,15 +186,12 @@ def get_configs(
 
 
 @pytest_asyncio.fixture(scope="session")
-async def llama3_torchstore_setup():
+async def setup_test():
     """
-    Pytest fixture to load Llama 3.1 8B-Instruct. We use the loaded state dict as SOT for validation.
-    Uses session scope so it's only called once when both tests are run.
+    Pytest fixture to load Llama 3.1 8B-Instruct. We use the loaded state dict
+    as the SOT for validation. Uses session scope so it's only called once
+    across UT.
     """
-    print("=== PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===")
-
-    store = await ts.initialize()
-
     model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
     # Load the model from local path - using device_map="auto" for efficient loading
@@ -207,76 +204,70 @@ async def llama3_torchstore_setup():
 
     original_state_dict = model.state_dict()
     print(f"Original state dict has {len(original_state_dict)} parameters")
-    converted_state_dict = convert_state_dict(original_state_dict)
-    print(f"Converted state dict has {len(converted_state_dict)} parameters")
+    hf_state_dict = convert_state_dict(original_state_dict)
+    print(f"Converted state dict has {len(hf_state_dict)} parameters")
 
-    return store, converted_state_dict
+    return hf_state_dict
 
 
 async def run_rl_trainer(worker_size) -> None:
     """
-    1. Spawn the trainer.
-    2. Inject torchstore references via setup call.
-    2. Call push weights.
+    Spawn the RL trainer
+    Args:
+        worker_size: Number of workers/procs.
     """
     cfg: DictConfig = OmegaConf.load("apps/rl/llama3_8b.yaml")
     rl_trainer = await spawn_service(
-        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+        ServiceConfig(procs_per_replica=worker_size, with_gpus=True, num_replicas=1),
         RLTrainer,
         **cfg.trainer,
     )
     # Push the weights to torchstore
     await rl_trainer.push_weights.choose()
 
 
-async def run_policy_integration(store, worker_size) -> Dict[str, torch.Tensor]:
+async def run_policy_integration(worker_size) -> Dict[str, torch.Tensor]:
     """
-    Common helper function to test Policy integration with different GPU configurations.
+    Launch the policy service.
 
     Args:
         store: TorchStore instance
-        original_state_dict: Original state dict for validation
-        num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
+        worker_size: Number of workers/procs (2+ for tensor parallel)
     """
-    print(f"=== PHASE 2: Testing Policy Integration (Workers: {worker_size}) ===")
+    print(f"=== PHASE 2: Launching Policy Engine (Workers: {worker_size}) ===")
 
     policy_config, service_config = get_configs(
         worker_size=worker_size, model_name="meta-llama/Llama-3.1-8B-Instruct"
     )
-    policy = await spawn_service(
-        service_config, Policy, config=policy_config, store=store
-    )
+    policy = await spawn_service(service_config, Policy, config=policy_config)
 
     # Policy engine start with default version 0 that gets incremented.
     print("Calling Policy.update() to load weights from torchstore...")
     await policy.update_weights.call()
     print(
         "Successfully called Policy.update_weights() to load weights from torchstore!"
     )
-    # We get the result as a list.
-    #results = await policy._get_model_params.call()
-    #assert len(results) == 1
-    #print("Successfully got model state dict after update")
-    #return results[0]
-    return {}
+    results = await policy._get_model_params.call()
+    assert len(results) == 1
+    print("Successfully got model state dict after update")
+    return results[0]
 
 
 @pytest.mark.asyncio
 @requires_cuda
-async def test_llama3_policy_update_single():
+async def test_llama3_policy_update_single(setup_test):
     print("Starting Llama 3 8B torchstore test (single GPU)...")
 
-    # store, original_state_dict = llama3_torchstore_setup
     await ts.initialize()
+    expected_state_dict = setup_test
     await run_rl_trainer(worker_size=1)
-    loaded_state_dict = await run_policy_integration(None, worker_size=1)
-    assert False, "Planned failure"
+    loaded_state_dict = await run_policy_integration(worker_size=1)
 
     # validating for single resource case.
-    # validate_loaded_tensors_equals_original(
-    #     loaded_state_dict, original_state_dict, tensor_parallel_size=0, rank=0
-    # )
-
+    validate_loaded_tensors_equals_original(
+        loaded_state_dict, expected_state_dict, tensor_parallel_size=0, rank=0
+    )
     print(
         "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
    )
+    assert False, "Planned failure"
```
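For reference, here is a self-contained sketch of the session-scoped async fixture pattern the test now relies on (pytest-asyncio). The 8B model load is replaced by a cheap stand-in dict; `setup_test` and `test_llama3_policy_update_single` mirror the names in the diff, while everything else is hypothetical.

```python
# test_fixture_sketch.py - minimal sketch of a session-scoped async fixture
# injected into a test by parameter name, as in the diff above.
import pytest
import pytest_asyncio


@pytest_asyncio.fixture(scope="session")
async def setup_test():
    # Runs once per session; the real fixture loads Llama 3.1 8B-Instruct
    # and returns its converted state dict as the source of truth.
    return {"layer.weight": [1.0, 2.0]}  # stand-in for hf_state_dict


@pytest.mark.asyncio
async def test_llama3_policy_update_single(setup_test):
    # Pytest injects the fixture's return value via the parameter name.
    expected_state_dict = setup_test
    assert "layer.weight" in expected_state_dict
```

Because the fixture is session-scoped, the expensive checkpoint load happens once and is shared by every test that requests `setup_test`, while the `ts.initialize()` call now lives inside each test rather than in the fixture.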
