7
7
# pyre-unsafe
8
8
9
9
import pickle
10
- from typing import Any , Callable , Coroutine , Iterable , List , TYPE_CHECKING
10
+ from typing import Any , Callable , cast , Coroutine , Iterable , List , TYPE_CHECKING
11
11
12
12
import monarch
13
13
import pytest
@@ -57,6 +57,12 @@ async def allocate() -> ProcMesh:
57
57
58
58
59
59
class MyActor :
60
+ def __init__ (self ) -> None :
61
+ # Note: for the same actor, its rank on the root mesh could be different
62
+ # from its rank on the mesh it is cast to. This is because the cast
63
+ # mesh could be a sliced mesh.
64
+ self ._rank_on_root_mesh : int = - 1
65
+
60
66
async def handle (
61
67
self ,
62
68
mailbox : Mailbox ,
@@ -68,8 +74,21 @@ async def handle(
68
74
local_state : Iterable [Any ],
69
75
response_port : "PortProtocol[Any]" ,
70
76
) -> None :
71
- assert rank is not None
72
- response_port .send (f"rank: { rank } " )
77
+ match method :
78
+ case MethodSpecifier .Init ():
79
+ # Since this actor is spawn from the root proc mesh, the rank
80
+ # passed from init should be the rank on the root mesh.
81
+ self ._rank_on_root_mesh = rank
82
+ response_port .send (None )
83
+ return None
84
+ case MethodSpecifier .ReturnsResponse (name = _):
85
+ response_port .send (self ._rank_on_root_mesh )
86
+ return None
87
+ case MethodSpecifier .ExplicitPort (name = _):
88
+ response_port .exception (
89
+ NotImplementedError ("ExplicitPort is not supported yet" )
90
+ )
91
+ return None
73
92
74
93
75
94
# TODO - re-enable after resolving T232206970
@@ -95,35 +114,70 @@ async def run() -> None:
95
114
run ()
96
115
97
116
98
- async def verify_cast (
117
+ async def spawn_actor_mesh (proc_mesh : ProcMesh ) -> PythonActorMesh :
118
+ actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
119
+ # init actors to record their root ranks
120
+ receiver : PortReceiver
121
+ handle , receiver = proc_mesh .client .open_port ()
122
+ port_ref = handle .bind ()
123
+
124
+ message = PythonMessage (
125
+ PythonMessageKind .CallMethod (MethodSpecifier .Init (), port_ref ),
126
+ pickle .dumps (None ),
127
+ )
128
+ actor_mesh .cast (Selection .all (), message )
129
+ # wait for init to complete
130
+ for _ in range (len (actor_mesh .shape .ndslice )):
131
+ await receiver .recv_task ()
132
+
133
+ return actor_mesh
134
+
135
+
136
+ async def cast_to_call (
137
+ actor_mesh : PythonActorMesh | PythonActorMeshRef ,
138
+ mailbox : Mailbox ,
139
+ message : PythonMessage ,
140
+ ) -> None :
141
+ sel = Selection .all ()
142
+ if isinstance (actor_mesh , PythonActorMesh ):
143
+ actor_mesh .cast (sel , message )
144
+ elif isinstance (actor_mesh , PythonActorMeshRef ):
145
+ actor_mesh .cast (mailbox , sel , message )
146
+
147
+
148
+ async def verify_cast_to_call (
99
149
actor_mesh : PythonActorMesh | PythonActorMeshRef ,
100
150
mailbox : Mailbox ,
101
- cast_ranks : List [int ],
151
+ root_ranks : List [int ],
102
152
) -> None :
103
153
receiver : PortReceiver
104
154
handle , receiver = mailbox .open_port ()
105
155
port_ref = handle .bind ()
106
156
157
+ # Now send the real message
107
158
message = PythonMessage (
108
159
PythonMessageKind .CallMethod (MethodSpecifier .ReturnsResponse ("echo" ), port_ref ),
109
160
pickle .dumps ("ping" ),
110
161
)
111
- sel = Selection .from_string ("*" )
112
- if isinstance (actor_mesh , PythonActorMesh ):
113
- actor_mesh .cast (sel , message )
114
- elif isinstance (actor_mesh , PythonActorMeshRef ):
115
- actor_mesh .cast (mailbox , sel , message )
162
+ await cast_to_call (actor_mesh , mailbox , message )
116
163
117
164
rcv_ranks = []
118
- for _ in range (len (cast_ranks )):
165
+ for _ in range (len (root_ranks )):
119
166
message = await receiver .recv_task ()
120
167
result_kind = message .kind
121
168
assert isinstance (result_kind , PythonMessageKind .Result )
122
- rank = result_kind .rank
123
- assert rank is not None
124
- rcv_ranks .append (rank )
125
- rcv_ranks .sort ()
126
- assert rcv_ranks == cast_ranks
169
+ cast_rank = result_kind .rank
170
+ assert cast_rank is not None
171
+ root_rank = cast (int , pickle .loads (message .message ))
172
+ rcv_ranks .append ((cast_rank , root_rank ))
173
+ rcv_ranks .sort (key = lambda pair : pair [0 ])
174
+ recv_cast_ranks , recv_root_ranks = zip (* rcv_ranks )
175
+ assert recv_root_ranks == tuple (
176
+ root_ranks
177
+ ), f"recv_root_ranks={ recv_root_ranks } , root_ranks={ tuple (root_ranks )} "
178
+ assert recv_cast_ranks == tuple (
179
+ range (len (root_ranks ))
180
+ ), f"recv_cast_ranks={ recv_cast_ranks } , root_ranks={ tuple (root_ranks )} "
127
181
# verify no more messages are received
128
182
with pytest .raises (TimeoutError ):
129
183
await receiver .recv_task ().with_timeout (1 )
@@ -136,8 +190,8 @@ async def test_cast_handle() -> None:
136
190
@run_on_tokio
137
191
async def run () -> None :
138
192
proc_mesh = await allocate ()
139
- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
140
- await verify_cast (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
193
+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
194
+ await verify_cast_to_call (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
141
195
142
196
await proc_mesh .stop_nonblocking ()
143
197
@@ -151,9 +205,11 @@ async def test_cast_ref() -> None:
151
205
@run_on_tokio
152
206
async def run () -> None :
153
207
proc_mesh = await allocate ()
154
- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
208
+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
155
209
actor_mesh_ref = actor_mesh .bind ()
156
- await verify_cast (actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 )))
210
+ await verify_cast_to_call (
211
+ actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 ))
212
+ )
157
213
158
214
await proc_mesh .stop_nonblocking ()
159
215
@@ -184,7 +240,7 @@ async def verify_slice(
184
240
assert (
185
241
sliced_shape .ranks () == replica_0_ranks + replica_1_ranks
186
242
), f"left is { sliced_shape .ranks ()} "
187
- await verify_cast (sliced_mesh , mailbox , sliced_shape .ranks ())
243
+ await verify_cast_to_call (sliced_mesh , mailbox , sliced_shape .ranks ())
188
244
189
245
assert sliced_shape .labels == ["replicas" , "hosts" , "gpus" ]
190
246
assert sliced_shape .ndslice .sizes == [2 , 4 , 3 ]
@@ -224,7 +280,8 @@ async def test_slice_actor_mesh_handle() -> None:
224
280
@run_on_tokio
225
281
async def run () -> None :
226
282
proc_mesh = await allocate ()
227
- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
283
+ actor_mesh = await spawn_actor_mesh (proc_mesh )
284
+
228
285
await verify_slice (actor_mesh , proc_mesh .client )
229
286
230
287
await proc_mesh .stop_nonblocking ()
@@ -239,7 +296,8 @@ async def test_slice_actor_mesh_ref() -> None:
239
296
@run_on_tokio
240
297
async def run () -> None :
241
298
proc_mesh = await allocate ()
242
- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
299
+ actor_mesh = await spawn_actor_mesh (proc_mesh )
300
+
243
301
actor_mesh_ref = actor_mesh .bind ()
244
302
await verify_slice (actor_mesh_ref , proc_mesh .client )
245
303
0 commit comments