@@ -23,49 +23,13 @@ limitations under the License.
2323
2424#include " absl/hash/hash.h"
2525#include " absl/status/statusor.h"
26+ #include " xla/backends/gpu/collectives/gpu_collectives.h"
2627#include " xla/core/collectives/clique_id.h"
2728#include " xla/core/collectives/clique_key.h"
2829#include " xla/service/global_device_id.h"
29- #include " xla/tsl/lib/gtl/int_type.h"
3030
3131namespace xla ::gpu {
3232
33- TSL_LIB_GTL_DEFINE_INT_TYPE (NcclStreamId, uint64_t ); // NOTE(review): presumably declares `NcclStreamId` as a strongly-typed uint64_t wrapper -- confirm against int_type.h.
34-
35- // A standalone library without any dependencies on NCCL that allows us to
36- // include this header in all of XLA without worrying about NCCL availability.
37-
38- // ===----------------------------------------------------------------------===//
39- // AsyncStreamKind
40- // ===----------------------------------------------------------------------===//
41-
42- // We include a stream kind in the NCCL clique key because in XLA we do not
43- // share communicators for collective operations of different kinds (CUDA-graph
44- // launched, async collectives, sync collectives) as it can lead to deadlocks.
45- //
46- // We carefully isolate different kinds of collectives using separate
47- // communicators and guarantee that all collective operations have a total order
48- // that will not create a deadlock.
49- //
50- // See more details in `nccl_clique` library.
51-
52- enum class AsyncStreamKind : int64_t {
53- kCollective = 0 , // Stream for asynchronous collective ops.
54- kP2P0 = 1 , // One stream for P2P Send and Recv ops.
55- kP2P1 = 2 , // Another stream for P2P Send and Recv ops.
56- kMemCpyP2P = 3 , // Stream for MemCpyP2P. Keep this last: kAsyncStreamTotal is computed from it.
57- };
58-
59- constexpr static int64_t kAsyncStreamTotal =
60- static_cast <int64_t >(AsyncStreamKind::kMemCpyP2P ) + 1 ; // One past the largest AsyncStreamKind value, i.e. the number of stream kinds.
61-
62- // Assigns a stream ID for synchronous or asynchronous execution: 0 when `is_async` is false,
63- // otherwise the numeric value of `stream_kind` plus one. These IDs can be used, for example, to look up the NCCL communicator.
64- inline NcclStreamId GetStreamId (
65- bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective ) {
66- return NcclStreamId (is_async ? static_cast <uint64_t >(stream_kind) + 1 : 0 );
67- }
68-
6933// ===----------------------------------------------------------------------===//
7034// NcclCliqueKey
7135// ===----------------------------------------------------------------------===//
@@ -79,11 +43,11 @@ class NcclCliqueKey : public CliqueKey {
7943 public:
8044 explicit NcclCliqueKey (
8145 std::vector<GlobalDeviceId> devices,
82- NcclStreamId stream_id = NcclStreamId (0 ),
46+ CollectiveStreamId stream_id = CollectiveStreamId (0 ),
8347 AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
8448 std::vector<std::vector<GlobalDeviceId>> participant_groups = {});
8549
86- NcclStreamId stream_id () const ;
50+ CollectiveStreamId stream_id () const ;
8751
8852 // Returns true if this clique is a subset of `other`: both cliques have the
8953 // same `stream_id` and all clique devices are part of `other` clique.
@@ -103,7 +67,7 @@ class NcclCliqueKey : public CliqueKey {
10367 private:
10468 void HashValue (absl::HashState state) const final ;
10569
106- NcclStreamId stream_id_;
70+ CollectiveStreamId stream_id_;
10771 AsyncStreamKind stream_kind_;
10872
10973 // The full list of groups across all devices which this clique is a part of.
0 commit comments