Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,21 @@ class Policy(PolicyInterface):
lora_request: LoRARequest | None = None
tokenization_kwargs: dict = field(default_factory=dict)
policy_worker: "PolicyWorker" = None
store: MultiProcessStore | None = None

def __post_init__(self):
    """Initialise per-instance state that is not a constructor field.

    The weights version counter starts at zero, and the run-task and
    proc-mesh handles start out unset (``None``).
    """
    # Version counter for the policy weights; begins at 0.
    self.weights_version: int = 0

    # Proc-mesh handles are unset until assigned after construction.
    self._worker_procs: ProcMesh | None = None
    self._policy_proc: ProcMesh | None = None

    # No background task exists yet.
    self._run_task: asyncio.Task | None = None

@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Policy"],
*,
process_config: ProcessConfig,
config: PolicyConfig,
store: MultiProcessStore | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does spawn services need to know to pass this in?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it has to be passed in from the top level so that the same store can also be used for the Trainer.

**kwargs,
) -> "Policy":
# Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
Expand All @@ -132,7 +135,11 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
policy = await policy_proc.spawn(
actor_name, cls, config=config, policy_worker=workers
actor_name,
cls,
config=config,
policy_worker=workers,
store=store,
)
policy._policy_proc = policy_proc
policy._worker_procs = worker_procs
Expand Down Expand Up @@ -160,7 +167,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
async def setup(self):
# Set up policy_worker
assert self.policy_worker is not None, "Policy worker should not be None"
await self.policy_worker.setup.call()
await self.policy_worker.setup.call(store=self.store)

self.request_id = 0
self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
Expand Down Expand Up @@ -313,9 +320,20 @@ async def run(self):
fut.set_result(request_output.outputs)

@endpoint
async def update_weights(self):
async def update_weights(self) -> int:
"""Update the policy weights."""
pass
# Wait for all current requests to finish, then publish model weights
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a check that the new version exists on the store?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just thinking that it would fail when trying to look up the key, and that would be obvious enough for now, but I could make the check more explicit.

futures = [fut for _, fut in self.requests.values()]
if futures:
await asyncio.gather(*futures)
await self.policy_worker.update.call()
self.weights_version += 1
return self.weights_version

@endpoint
async def get_version(self) -> int:
    """Return the weights version currently recorded on this policy."""
    current_version = self.weights_version
    return current_version

@endpoint
async def stop(self):
Expand Down
Loading