Skip to content

Commit 2793033

Browse files
committed
docs
1 parent fe0d924 commit 2793033

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/integration_tests/test_policy_update.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,13 @@ async def run_policy_integration(worker_size) -> Dict[str, torch.Tensor]:
262262
@pytest.mark.asyncio
263263
@requires_cuda
264264
async def test_llama3_policy_update_single(setup_test):
265-
print("Starting Llama 3 8B torchstore test (single GPU)...")
266-
265+
"""
266+
1. Loads weights from the HF model into an in-memory state dict (the source of truth).
267+
2. Initializes RLTrainer and makes the weights available in torchstore.
268+
3. Initializes Policy and calls update_weights() to load the weights from torchstore.
269+
4. Validates the policy weights against the source of truth.
270+
"""
271+
logger.info("Starting Llama 3 8B torchstore test (single GPU)...")
267272
await ts.initialize()
268273
expected_state_dict = setup_test
269274
await run_rl_trainer(worker_size=1)
@@ -273,7 +278,7 @@ async def test_llama3_policy_update_single(setup_test):
273278
validate_loaded_tensors_equals_original(
274279
loaded_state_dict, expected_state_dict, tensor_parallel_size=0, rank=0
275280
)
276-
print(
281+
logger.info(
277282
"Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
278283
)
279284
assert False, "Planned failure"

0 commit comments

Comments
 (0)