
Commit a31bd76

TroyGarden authored and meta-codesync[bot] committed
explore using a side stream for two data-dependent all_to_all_single comms (#3440)
Summary:
Pull Request resolved: #3440

# context
* table-wise-row-wise (TWRW) sharding takes advantage of high-bandwidth intra-node comms for the data-intensive row-wise sharded embedding table pooling.
* it uses [two output dist components](https://github.com/meta-pytorch/torchrec/blob/release/v1.3.0/torchrec/distributed/sharding/twrw_sharding.py#L479-L490): an intra-node dist and a cross-node dist. The cross-node dist relies on the data/result of the intra-node dist.
* this data dependency creates a blocking situation on the main cuda (compute) stream, as shown below (nccl:_reduce_scatter for the intra-node dist, nccl:_all_to_all for the cross-node dist) {F1982557282}

# experiment
* the correct approach is to use a side stream to process the data-dependent comms
* without a side stream ([trace](https://drive.google.com/file/d/1lpa-NrBD0IWcpskdN1Lwiu0XcSTe01bW/view?usp=sharing)): the first comm blocks the main stream execution {F1982557422}
* with a side stream ([trace](https://drive.google.com/file/d/1FqNpq4yMx9H6vL47S8KX5dvk2PJv_QGa/view?usp=sharing)): both comms are non-blocking on the main stream {F1982557381}

Reviewed By: spmex

Differential Revision: D82002643

fbshipit-source-id: 00ee3e7b20f4ed0b799b3c8a49a3a5f7566f87c1
1 parent eaea63d commit a31bd76
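Before the diff below, here is a minimal, self-contained sketch of the side-stream pattern described in the summary: two data-dependent all_to_all_single comms launched without blocking the main stream. It assumes an already-initialized NCCL process group and equally sized per-rank CUDA tensors; run_dependent_a2a and its parameters are illustrative names only and are not part of benchmark_comms.py.

import torch
import torch.distributed as dist


def run_dependent_a2a(x: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    out1 = torch.zeros_like(x)
    out2 = torch.zeros_like(x)

    # comms1 is launched asynchronously from the main stream
    req1 = dist.all_to_all_single(output=out1, input=x, group=pg, async_op=True)

    # comms2 depends on the result of comms1, so only a side stream waits for it
    side_stream = torch.cuda.Stream()
    out2.record_stream(side_stream)  # keep out2's memory reserved for side-stream work
    with torch.cuda.stream(side_stream):
        req1.wait()  # blocks the side stream, not the main stream
        y = torch.sigmoid(out1) + dist.get_rank(pg)
        req2 = dist.all_to_all_single(output=out2, input=y, group=pg, async_op=True)

    # ... independent compute can overlap with both comms on the main stream ...

    req2.wait()  # the main stream only blocks when the dependent result is needed
    return out2

The commit wires this same pattern into benchmark_comms.py as a2a_async_twice, adding record_function ranges and validation checks around each step, as shown in the diff.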

File tree

1 file changed (+81 −2 lines)

torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 81 additions & 2 deletions
@@ -148,7 +148,7 @@ def a2a_async_base(
             async_op=True,
         )
 
-    with record_function("## comms validation ##"):
+    with record_function("## comms pre-check ##"):
         # pre-check is performed before comms' done
         pre_checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
         # need this cuda.event to record the device-to-host data transfer
@@ -159,13 +159,15 @@ def a2a_async_base(
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
 
     ev_d2h.synchronize()  # make sure the pre_checks is available from the cpu side
-    with record_function(f"## post-comms compute: pre-check-{pre_checks}##"):
+    with record_function(f"## comms check and pre-check: {pre_checks} ##"):
         # assertion fails without wait(), this wait() makes the main cuda stream wait
         # for the comms to finish, so the post-comms compute will be blocked until
         # the comms is done
         req.wait()
         checks = _validate(post_comms, ctx).to("cpu", non_blocking=True)
         ev_d2h.record()  # record the device-to-host data transfer
+
+    with record_function("## post-comms compute ##"):
         post_comms = _compute(
             dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
         )
@@ -176,6 +178,81 @@ def a2a_async_base(
         assert checks
 
 
+# all_to_all_single with async comms and a side stream
+def a2a_async_twice(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+) -> None:
+    with record_function("## pre-comms compute ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## pre-allocation ##"):
+        # use zeros instead of empty to make sure no previous data is used
+        post_comms1 = torch.zeros_like(pre_comms)
+        post_comms2 = torch.zeros_like(pre_comms)
+
+    with record_function("## comms1 ##"):
+        req1 = dist.all_to_all_single(
+            output=post_comms1,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## comms1 pre-validation ##"):
+        # pre-check is performed before comms1 is done
+        pre_checks1 = _validate(post_comms1, ctx).to("cpu", non_blocking=True)
+        # need this cuda.event to record the device-to-host data transfer
+        ev_d2h = torch.cuda.Event()
+        ev_d2h.record()
+
+    with record_function("## comms2 ##"):
+        side_stream = torch.cuda.Stream()
+        post_comms2.record_stream(side_stream)
+        with torch.cuda.stream(side_stream):
+            req1.wait()  # let the side stream wait for comms1 to finish
+            pre_comms = torch.sigmoid(post_comms1) + ctx.rank
+            req2 = dist.all_to_all_single(
+                output=post_comms2,
+                input=pre_comms,
+                group=ctx.pg,
+                async_op=True,
+            )
+
+    with record_function("## irrelevant compute1 ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    with record_function("## comms2 pre-validation ##"):
+        # pre-check is performed before comms2 is done, actually even before comms2 starts
+        pre_checks2 = _validate(post_comms2, ctx).to("cpu", non_blocking=True)
+        ev_d2h.record()  # record the device-to-host data transfer
+
+    with record_function("## irrelevant compute2 ##"):
+        pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
+
+    ev_d2h.synchronize()  # make sure the pre_checks are available from the cpu side
+    with record_function(f"## comms1 checks and pre-checks1 {pre_checks1} ##"):
+        req1.wait()  # let the main stream wait for comms1 to finish
+        checks1 = _validate(post_comms1, ctx).to("cpu", non_blocking=True)
+    with record_function(f"## comms2 checks and pre-checks2 {pre_checks2} ##"):
+        req2.wait()  # let the main stream wait for comms2 to finish
+        checks2 = _validate(post_comms2, ctx).to("cpu", non_blocking=True)
+        ev_d2h.record()  # record the device-to-host data transfer
+
+    with record_function("## post-comms compute ##"):
+        post_comms2 = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms2[0]
+        )
+
+    with record_function("## assert ##"):
+        # again, make sure the device-to-host data transfer is done before the assertion
+        ev_d2h.synchronize()
+        assert checks1 and checks2
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
@@ -195,6 +272,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         func = a2a_sync_base
     elif arg.name.startswith("a2a_async_base"):
        func = a2a_async_base
+    elif arg.name.startswith("a2a_async_twice"):
+        func = a2a_async_twice
    else:
        func = a2a_sync_base
 