explore using a side stream for two data-dependent all_to_all_single comms (#3440)
Summary:
Pull Request resolved: #3440
# context
* table-wise-row-wise (TWRW) sharding takes advantage of high-bandwidth intra-node comms for the data-intensive pooling of row-wise sharded embedding tables.
* it uses [two output dist components](https://github.com/meta-pytorch/torchrec/blob/release/v1.3.0/torchrec/distributed/sharding/twrw_sharding.py#L479-L490): an intra-node dist and a cross-node dist. The cross-node dist depends on the result of the intra-node dist.
* this data dependency creates a blocking situation on the main CUDA (compute) stream, as shown below (nccl:_reduce_scatter is the intra-node dist, nccl:_all_to_all is the cross-node dist); a minimal sketch of the blocking pattern follows the trace.
{F1982557282}
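For illustration only, a hedged sketch of this data-dependent pattern when both collectives run on the current (compute) stream. The function and argument names (`twrw_output_dist_blocking`, `intra_pg`, `cross_pg`) are hypothetical stand-ins for the actual TWRW output dist code linked above, and the tensor shapes are simplified.

```python
import torch
import torch.distributed as dist

def twrw_output_dist_blocking(
    local_pooled: torch.Tensor,      # partial pooled output on this rank
    intra_pg: dist.ProcessGroup,     # ranks on the same host
    cross_pg: dist.ProcessGroup,     # one rank per host
) -> torch.Tensor:
    # Intra-node dist: reduce-scatter the row-wise partial sums across
    # ranks on the same host (shows up as nccl:_reduce_scatter).
    intra_out = torch.empty(
        local_pooled.shape[0] // intra_pg.size(),
        local_pooled.shape[1],
        device=local_pooled.device,
        dtype=local_pooled.dtype,
    )
    dist.reduce_scatter_tensor(intra_out, local_pooled, group=intra_pg)

    # Cross-node dist: all_to_all of the reduced result across hosts
    # (shows up as nccl:_all_to_all). Its input is the output of the
    # first collective, and with both enqueued on the current (compute)
    # stream, every later kernel on that stream waits behind them.
    cross_out = torch.empty_like(intra_out)
    dist.all_to_all_single(cross_out, intra_out, group=cross_pg)
    return cross_out
```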
# experiment
* the fix is to issue the data-dependent comms on a side CUDA stream so they do not block the main (compute) stream; a sketch of this pattern follows this list
* without side stream: [trace](https://drive.google.com/file/d/1lpa-NrBD0IWcpskdN1Lwiu0XcSTe01bW/view?usp=sharing)
the first collective blocks execution on the main stream
{F1982557422}
* with side stream: [trace](https://drive.google.com/file/d/1FqNpq4yMx9H6vL47S8KX5dvk2PJv_QGa/view?usp=sharing)
both collectives run without blocking the main stream
{F1982557381}
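Under the same assumptions as the sketch above (hypothetical names, simplified shapes, not the actual torchrec implementation), a minimal side-stream variant: the comms are enqueued on a caller-provided `torch.cuda.Stream`, ordered against the main stream with `wait_stream`, and `record_stream` keeps the caching allocator from reusing memory that is still in flight.

```python
import torch
import torch.distributed as dist

def twrw_output_dist_side_stream(
    local_pooled: torch.Tensor,
    intra_pg: dist.ProcessGroup,
    cross_pg: dist.ProcessGroup,
    side_stream: torch.cuda.Stream,
) -> torch.Tensor:
    main_stream = torch.cuda.current_stream()
    # The side stream must not start until the producer of `local_pooled`
    # on the main stream has finished.
    side_stream.wait_stream(main_stream)

    with torch.cuda.stream(side_stream):
        # Mark the input as used on the side stream so its memory is not
        # reused while the comms are in flight.
        local_pooled.record_stream(side_stream)

        # Same two data-dependent collectives as before, now enqueued on
        # the side stream; the main stream keeps launching compute kernels.
        intra_out = torch.empty(
            local_pooled.shape[0] // intra_pg.size(),
            local_pooled.shape[1],
            device=local_pooled.device,
            dtype=local_pooled.dtype,
        )
        dist.reduce_scatter_tensor(intra_out, local_pooled, group=intra_pg)

        cross_out = torch.empty_like(intra_out)
        dist.all_to_all_single(cross_out, intra_out, group=cross_pg)

    # Consumers on the main stream must see the finished result.
    main_stream.wait_stream(side_stream)
    cross_out.record_stream(main_stream)
    return cross_out
```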
Reviewed By: spmex
Differential Revision: D82002643
fbshipit-source-id: 00ee3e7b20f4ed0b799b3c8a49a3a5f7566f87c1