Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/forge/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def set_env(self, addr: str, port: str):
os.environ["MASTER_PORT"] = port

@classmethod
async def launch(cls, **kwargs) -> "ForgeActor":
async def launch(cls, *args, **kwargs) -> "ForgeActor":
"""Provisions and deploys a new actor.

This method is used by `Service` to provision a new replica.
Expand All @@ -183,9 +183,8 @@ async def launch(cls, **kwargs) -> "ForgeActor":

proc_mesh = await get_proc_mesh(process_config=cfg)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
actor = await proc_mesh.spawn(actor_name, cls, **kwargs)
actor = await proc_mesh.spawn(actor_name, cls, *args, **kwargs)
actor._proc_mesh = proc_mesh

if hasattr(proc_mesh, "_hostname") and hasattr(proc_mesh, "_port"):
Expand All @@ -195,7 +194,7 @@ async def launch(cls, **kwargs) -> "ForgeActor":
return actor

@classmethod
async def as_actor(cls: Type[T], **actor_kwargs) -> T:
async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T:
"""
Spawns a single actor using the configuration stored in `.options()`, or defaults.

Expand All @@ -204,7 +203,7 @@ async def as_actor(cls: Type[T], **actor_kwargs) -> T:
If no configuration was stored, defaults to a single process with no GPU.
"""
logger.info("Spawning single actor %s", cls.__name__)
actor = await cls.launch(**actor_kwargs)
actor = await cls.launch(*args, **actor_kwargs)
return actor

@classmethod
Expand Down
20 changes: 18 additions & 2 deletions tests/unit_tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,18 @@
class Counter(ForgeActor):
"""Test actor that maintains a counter with various endpoints."""

def __init__(self, v: int):
def __init__(self, v: int, *args, **kwargs):
self.v = v
self.args = args
self.kwargs = kwargs

@endpoint
async def get_args(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove these tests, i think it's fine without

return self.args

@endpoint
async def get_kwargs(self):
return self.kwargs

@endpoint
async def incr(self):
Expand Down Expand Up @@ -83,13 +93,19 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica:
@pytest.mark.timeout(10)
async def test_as_actor_with_kwargs_config():
"""Test spawning a single actor with passing configs through kwargs."""
actor = await Counter.options(procs=1).as_actor(v=5)
actor = await Counter.options(procs=1).as_actor(
5, "hello", k=0
) # "hello" goes to args, k=0 goes to kwargs

try:
assert await actor.value.choose() == 5

# Test increment
await actor.incr.choose()
assert await actor.value.choose() == 6
# Check that the positional/keyword arguments were passed correctly to the actor
assert await actor.get_args.choose() == ("hello",)
assert await actor.get_kwargs.choose() == {"k": 0}

finally:
await Counter.shutdown(actor)
Expand Down
Loading