Merged
33 changes: 14 additions & 19 deletions src/forge/actors/generator.py
@@ -492,16 +492,14 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
         await stop_proc_mesh(actor._generator_proc)

     @endpoint
-    async def _test_save_model_params(self):
-        """Save model parameters before weight update, used for tesing purposes only."""
-        logger.info("[Generator] save model parameters for testing.")
-        await self.worker._test_save_model_params.call()
+    async def save_model_params(self):
+        """Used for debugging purposes. Save model parameters before weight update."""
+        await self.worker.save_model_params.call()

     @endpoint
-    async def _test_validate_model_params(self, validate_fn):
-        """Validate updated model params using validate_fn."""
-        logger.info("[Generator] start validating model parameters.")
-        return await self.worker._test_validate_model_params.call(validate_fn)
+    async def validate_model_params(self, validate_fn):
+        """Used for debugging purposes. Validate saved params using validate_fn."""
+        return await self.worker.validate_model_params.call(validate_fn)


@dataclass
@@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor):
     """

     vllm_config: VllmConfig
-    # TODO: Remove below param
-    _test_prev_params = {}

     @endpoint
     async def setup(self):
@@ -605,20 +601,19 @@ async def update_weights(self, version: int) -> None:
         t.stop()

     @endpoint
-    async def _test_save_model_params(self):
-        """Save model parameters before weight update, used for tesing purposes only."""
-        logger.info("[GeneratorWorker] save model parameters for testing.")
+    async def save_model_params(self):
Member:
Could this method return all the model params? Instead of saving them to the model?

Contributor:
Sorry I am late to the party. The short answer is no: the workers are remote and it's not feasible to send all the parameters through RPC.
"""Used for debugging purposes. Save model parameters before weight update."""
self._debug_saved_params = {}
for name, param in self.worker.model_runner.model.named_parameters():
self._test_prev_params[name] = param.detach().cpu()
self._debug_saved_params[name] = param.detach().cpu()
logger.info(
"[GeneratorWorker] finished saving model parameters, len = %d",
len(self._test_prev_params),
len(self._debug_saved_params),
)

     @endpoint
-    async def _test_validate_model_params(self, validate_fn):
-        """Validate updated model params using validate_fn."""
-        logger.info("[GeneratorWorker] start validating model parameters.")
+    async def validate_model_params(self, validate_fn):
Member:
Actually - could we expose an endpoint to get the underlying model? Then these methods could be entirely deleted, no?
Sorry if I'm misunderstanding. But I don't like having a debug param saved on this model :/

Contributor Author:
I think it's a fair tradeoff to expose certain methods for debugging purposes. The beauty of exposing a validator function here is that it's a very flexible and reliable way to poke at the properties on each worker shard, and it saves the cost of message SerDe. Another restriction here is that Policy is behind the Service interface....

Member:
> saves the cost of message SerDe.

Do we imagine testing w/ these methods at a substantial scale?

Contributor Author:
> Do we imagine testing w/ these methods at a substantial scale?

My hunch is that, so long as we're still iterating on the weight sync strategy, having some capability to poke into the states across multiple workers in a multi-node, multi-replica run might be super valuable.
"""Used for debugging purposes. Validate saved params using validate_fn."""
return validate_fn(
self._test_prev_params, self.worker.model_runner.model, logger
self._debug_saved_params, self.worker.model_runner.model, logger
)
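For context on the tradeoff debated above: the validator function is shipped to each worker shard and executed there, so only its small return value crosses the RPC boundary rather than the full parameter set. Below is a minimal sketch of such a validator, assuming the `(saved_params, model, logger)` signature visible in the diff and the falsy-on-success convention implied by the integration test that follows; the name and the tolerance-based comparison are illustrative, not the actual test helper.

```python
import torch

def validate_params_unchanged(saved_params, model, logger):
    """Return None if current model params match the saved snapshot, else the error."""
    try:
        for name, param in model.named_parameters():
            # Compare against the CPU snapshot taken by save_model_params.
            if not torch.allclose(param.detach().cpu(), saved_params[name]):
                raise AssertionError(f"parameter {name} changed after weight update")
        logger.info("validated %d parameters against snapshot", len(saved_params))
        return None  # falsy: validation passed
    except Exception as e:
        return e  # truthy: returned to the test process for assertion
```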
8 changes: 4 additions & 4 deletions tests/integration_tests/test_policy_update.py
@@ -254,18 +254,18 @@ async def test_sanity_check(self, _setup_and_teardown):
         # Setting everything to zero
         await rl_trainer.zero_out_model_states.call()
         await rl_trainer.push_weights.call(policy_version=v1)
-        await policy._test_save_model_params.fanout()
+        await policy.save_model_params.fanout()

         # Sanity check that before update all the tests pass
-        all_errs = await policy._test_validate_model_params.fanout(
+        all_errs = await policy.validate_model_params.fanout(
             _test_validate_params_unchanged
         )
         for errs in all_errs:
             for _, e in errs.items():
                 assert not e, f"Validation failed with exception: {e}"

         await policy.update_weights.fanout(version=v1)
-        all_errs = await policy._test_validate_model_params.fanout(
+        all_errs = await policy.validate_model_params.fanout(
             _test_validate_params_all_zeros
         )
         for errs in all_errs:
@@ -274,7 +274,7 @@

         # Reloading v0, getting back original weights
         await policy.update_weights.fanout(version=v0)
-        all_errs = await policy._test_validate_model_params.fanout(
+        all_errs = await policy.validate_model_params.fanout(
             _test_validate_params_unchanged
         )
         for errs in all_errs:
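The all-zeros helper referenced above presumably mirrors the unchanged check after `zero_out_model_states` and the v1 push; here is a hypothetical sketch under the same signature and return convention (again, not the actual `_test_validate_params_all_zeros` implementation):

```python
import torch

def validate_params_all_zeros(saved_params, model, logger):
    """Return None if every current parameter tensor is all zeros, else the error."""
    try:
        for name, param in model.named_parameters():
            if torch.count_nonzero(param.detach()) > 0:
                raise AssertionError(f"parameter {name} is not all zeros")
        logger.info("all parameters are zero, as expected after the zeroed push")
        return None
    except Exception as e:
        return e
```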