Skip to content

Commit 7e0b0fe

Browse files
ezhulenev and Google-ML-Automation
authored and committed
[xla:collectives] NFC: Introduce GpuCollectives interface
PiperOrigin-RevId: 702232349
1 parent 1000ed5 commit 7e0b0fe

File tree

12 files changed

+153
-76
lines changed

12 files changed

+153
-76
lines changed

xla/backends/gpu/collectives/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@ package_group(
1616
],
1717
)
1818

19+
cc_library(
20+
name = "gpu_collectives",
21+
srcs = ["gpu_collectives.cc"],
22+
hdrs = ["gpu_collectives.h"],
23+
deps = [
24+
"//xla/core/collectives",
25+
"//xla/tsl/lib/gtl:int_type",
26+
],
27+
)
28+
1929
cc_library(
2030
name = "nccl_errors",
2131
hdrs = if_gpu_is_configured(["nccl_errors.h"]),
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/backends/gpu/collectives/gpu_collectives.h"
17+
18+
#include <cstdint>
19+
20+
namespace xla::gpu {
21+
22+
CollectiveStreamId GetCollectiveStreamId(bool is_async,
                                         AsyncStreamKind stream_kind) {
  // Synchronous execution always shares stream ID 0; asynchronous execution
  // maps each stream kind to its enum value shifted up by one, so that the
  // async IDs never collide with the synchronous ID.
  //
  // TODO(ezhulenev): This implementation does not look correct as stream IDs
  // are not really unique. Figure out if it's the case and fix either the code
  // or the documentation.
  if (!is_async) {
    return CollectiveStreamId(0);
  }
  return CollectiveStreamId(static_cast<int64_t>(stream_kind) + 1);
}
30+
31+
} // namespace xla::gpu
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_

#include <cstdint>

#include "xla/core/collectives/collectives.h"
#include "xla/tsl/lib/gtl/int_type.h"

namespace xla::gpu {

// In XLA:GPU we use different streams for different kinds of collective
// operations, and include the async stream kind into the GPU clique key.
//
// We carefully isolate different kinds of collectives using separate
// communicators and guarantee that all collective operations have a total
// order that will not create a deadlock.
enum class AsyncStreamKind : int64_t {
  kCollective = 0,  // Stream for asynchronous collective ops.
  kP2P0 = 1,        // One stream for P2P Send and Recv ops.
  kP2P1 = 2,        // Another stream for P2P Send and Recv ops.
  kMemCpyP2P = 3,   // Stream for MemCpyP2P.
};

// Total number of async stream kinds defined above.
inline constexpr int64_t kAsyncStreamTotal =
    static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1;

// Strongly-typed wrapper to represent collective stream ID.
TSL_LIB_GTL_DEFINE_INT_TYPE(CollectiveStreamId, uint64_t);

// Assigns a unique ID to a stream for asynchronous or synchronous execution.
// These IDs can be used, for example, to look up the NCCL communicator.
CollectiveStreamId GetCollectiveStreamId(
    bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective);

// XLA:GPU extension of the Collectives interface with GPU-specific APIs.
// NOTE(review): currently an empty marker interface; GPU-specific virtual
// methods are presumably added in follow-up changes.
class GpuCollectives : public Collectives {
 public:
};

}  // namespace xla::gpu

#endif  // XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_

xla/service/gpu/runtime/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ cc_library(
7272
"//xla:status_macros",
7373
"//xla:types",
7474
"//xla:util",
75+
"//xla/backends/gpu/collectives:gpu_collectives",
7576
"//xla/core/collectives:communicator",
7677
"//xla/ffi:call_frame",
7778
"//xla/ffi:ffi_api",
@@ -215,6 +216,7 @@ cc_library(
215216
"//xla:shape_util",
216217
"//xla:util",
217218
"//xla:xla_data_proto_cc",
219+
"//xla/backends/gpu/collectives:gpu_collectives",
218220
"//xla/backends/gpu/collectives:nccl_communicator",
219221
"//xla/core/collectives:clique_id",
220222
"//xla/core/collectives:communicator",
@@ -258,6 +260,7 @@ cc_library(
258260
":nccl_clique_key",
259261
"//xla:shape_util",
260262
"//xla:xla_data_proto_cc",
263+
"//xla/backends/gpu/collectives:gpu_collectives",
261264
"//xla/core/collectives:clique_id",
262265
"//xla/core/collectives:communicator",
263266
"//xla/core/collectives:rank_id",
@@ -318,6 +321,7 @@ cc_library(
318321
hdrs = ["nccl_clique_key.h"],
319322
compatible_with = get_compatible_with_portable(),
320323
deps = [
324+
"//xla/backends/gpu/collectives:gpu_collectives",
321325
"//xla/core/collectives:clique_id",
322326
"//xla/core/collectives:clique_key",
323327
"//xla/core/collectives:rank_id",
@@ -342,6 +346,7 @@ xla_cc_test(
342346
srcs = ["nccl_clique_key_test.cc"],
343347
deps = [
344348
":nccl_clique_key",
349+
"//xla/backends/gpu/collectives:gpu_collectives",
345350
"//xla/core/collectives:clique_id",
346351
"//xla/service:global_device_id",
347352
"@com_google_absl//absl/container:btree",
@@ -882,6 +887,7 @@ cc_library(
882887
":thunk",
883888
"//xla:shape_util",
884889
"//xla:status_macros",
890+
"//xla/backends/gpu/collectives:gpu_collectives",
885891
"//xla/core/collectives:communicator",
886892
"//xla/hlo/ir:hlo",
887893
"//xla/service:collective_ops_utils",
@@ -960,6 +966,7 @@ cc_library(
960966
"//xla:shape_util",
961967
"//xla:util",
962968
"//xla:xla_data_proto_cc",
969+
"//xla/backends/gpu/collectives:gpu_collectives",
963970
"//xla/core/collectives:communicator",
964971
"//xla/core/collectives:rank_id",
965972
"//xla/hlo/ir:hlo",

xla/service/gpu/runtime/command_buffer_cmd.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ limitations under the License.
3737
#include "absl/strings/string_view.h"
3838
#include "absl/synchronization/mutex.h"
3939
#include "absl/types/span.h"
40+
#include "xla/backends/gpu/collectives/gpu_collectives.h"
4041
#include "xla/ffi/api/c_api.h"
4142
#include "xla/hlo/ir/hlo_computation.h"
4243
#include "xla/service/buffer_assignment.h"
@@ -1013,8 +1014,8 @@ class CollectiveCmd : public CommandBufferCmd {
10131014
return async_from_stream_id_ != execution_stream_id();
10141015
}
10151016

1016-
NcclStreamId nccl_stream_id() {
1017-
return xla::gpu::GetStreamId(IsAsync(), GetAsyncStreamKind());
1017+
CollectiveStreamId nccl_stream_id() {
1018+
return xla::gpu::GetCollectiveStreamId(IsAsync(), GetAsyncStreamKind());
10181019
}
10191020

10201021
ExecutionStreamId async_from_stream_id() const {

xla/service/gpu/runtime/nccl_all_to_all_thunk.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "absl/status/status.h"
2828
#include "absl/strings/substitute.h"
2929
#include "absl/synchronization/mutex.h"
30+
#include "xla/backends/gpu/collectives/gpu_collectives.h"
3031
#include "xla/core/collectives/communicator.h"
3132
#include "xla/hlo/ir/hlo_instruction.h"
3233
#include "xla/hlo/ir/hlo_instructions.h"
@@ -105,7 +106,7 @@ absl::Status NcclAllToAllStartThunk::Initialize(
105106
VLOG(5) << "Local device count: " << device_count_;
106107

107108
if (is_local() && p2p_memcpy_enabled_) {
108-
const NcclStreamId stream_id = nccl_stream_id();
109+
const CollectiveStreamId stream_id = nccl_stream_id();
109110
AsyncStreamKind stream_kind = GetAsyncStreamKind();
110111
TF_ASSIGN_OR_RETURN(
111112
CommunicatorHandle comm_handle,
@@ -136,7 +137,7 @@ absl::Status NcclAllToAllStartThunk::Initialize(
136137

137138
absl::Status NcclAllToAllStartThunk::Cleanup(const CleanupParams& params) {
138139
if (p2p_memcpy_enabled_) {
139-
const NcclStreamId stream_id = nccl_stream_id();
140+
const CollectiveStreamId stream_id = nccl_stream_id();
140141
AsyncStreamKind stream_kind = GetAsyncStreamKind();
141142
TF_ASSIGN_OR_RETURN(
142143
CommunicatorHandle comm_handle,

xla/service/gpu/runtime/nccl_api.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ limitations under the License.
2525
#include "absl/status/status.h"
2626
#include "absl/status/statusor.h"
2727
#include "absl/types/span.h"
28+
#include "xla/backends/gpu/collectives/gpu_collectives.h"
29+
#include "xla/core/collectives/clique_id.h"
2830
#include "xla/core/collectives/communicator.h"
2931
#include "xla/core/collectives/rank_id.h"
3032
#include "xla/service/collective_ops_utils.h"
31-
#include "xla/service/gpu/runtime/nccl_clique_key.h"
3233
#include "xla/shape_util.h"
3334
#include "xla/stream_executor/device_memory.h"
3435
#include "xla/stream_executor/device_memory_allocator.h"
@@ -47,7 +48,7 @@ namespace xla::gpu {
4748
// NCCL library so that no other parts of XLA should include nccl.h header
4849
// directly (or indirectly).
4950

50-
class NcclApi {
51+
class NcclApi : public GpuCollectives {
5152
public:
5253
virtual ~NcclApi() = default;
5354

xla/service/gpu/runtime/nccl_clique_key.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "absl/strings/str_format.h"
2525
#include "absl/strings/str_join.h"
2626
#include "absl/types/span.h"
27+
#include "xla/backends/gpu/collectives/gpu_collectives.h"
2728
#include "xla/core/collectives/clique_key.h"
2829
#include "xla/service/global_device_id.h"
2930
#include "tsl/platform/casts.h"
@@ -36,7 +37,7 @@ namespace xla::gpu {
3637
//===----------------------------------------------------------------------===//
3738

3839
NcclCliqueKey::NcclCliqueKey(
39-
std::vector<GlobalDeviceId> devices, NcclStreamId stream_id,
40+
std::vector<GlobalDeviceId> devices, CollectiveStreamId stream_id,
4041
AsyncStreamKind stream_kind,
4142
std::vector<std::vector<GlobalDeviceId>> participant_groups)
4243
: CliqueKey(std::move(devices)),
@@ -56,7 +57,7 @@ NcclCliqueKey::NcclCliqueKey(
5657
absl::c_sort(participant_groups_, compare_groups);
5758
}
5859

59-
NcclStreamId NcclCliqueKey::stream_id() const { return stream_id_; }
60+
CollectiveStreamId NcclCliqueKey::stream_id() const { return stream_id_; }
6061

6162
bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const {
6263
auto* other_nccl = tsl::down_cast<const NcclCliqueKey*>(&other);

xla/service/gpu/runtime/nccl_clique_key.h

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,49 +23,13 @@ limitations under the License.
2323

2424
#include "absl/hash/hash.h"
2525
#include "absl/status/statusor.h"
26+
#include "xla/backends/gpu/collectives/gpu_collectives.h"
2627
#include "xla/core/collectives/clique_id.h"
2728
#include "xla/core/collectives/clique_key.h"
2829
#include "xla/service/global_device_id.h"
29-
#include "xla/tsl/lib/gtl/int_type.h"
3030

3131
namespace xla::gpu {
3232

33-
TSL_LIB_GTL_DEFINE_INT_TYPE(NcclStreamId, uint64_t);
34-
35-
// A standalone library without any dependencies on NCCL that allows us to
36-
// include this header in all of XLA without worrying about NCCL availability.
37-
38-
//===----------------------------------------------------------------------===//
39-
// AsyncStreamKind
40-
//===----------------------------------------------------------------------===//
41-
42-
// We include a stream kind into the NCCL clique key because in XLA we do not
43-
// share communicators for collective operations of different kind (CUDA-graph
44-
// launched, async collectives, sync collectives) as it can lead to dead locks.
45-
//
46-
// We carefully isolate different kinds of collectives using separate
47-
// communicators and guarantee that all collective operations have a total order
48-
// that will not create a deadlock.
49-
//
50-
// See more details in `nccl_clique` library.
51-
52-
enum class AsyncStreamKind : int64_t {
53-
kCollective = 0, // Stream for asynchronous collective ops.
54-
kP2P0 = 1, // One Stream for P2P Send and Recv ops.
55-
kP2P1 = 2, // Another Stream for P2P Send and Recv ops.
56-
kMemCpyP2P = 3, // Stream for MemCpyP2P
57-
};
58-
59-
constexpr static int64_t kAsyncStreamTotal =
60-
static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1;
61-
62-
// Assigns a unique ID to a stream for asynchronous or synchronous execution.
63-
// These IDs can be used, for example, to look up the NCCL communicator.
64-
inline NcclStreamId GetStreamId(
65-
bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective) {
66-
return NcclStreamId(is_async ? static_cast<uint64_t>(stream_kind) + 1 : 0);
67-
}
68-
6933
//===----------------------------------------------------------------------===//
7034
// NcclCliqueKey
7135
//===----------------------------------------------------------------------===//
@@ -79,11 +43,11 @@ class NcclCliqueKey : public CliqueKey {
7943
public:
8044
explicit NcclCliqueKey(
8145
std::vector<GlobalDeviceId> devices,
82-
NcclStreamId stream_id = NcclStreamId(0),
46+
CollectiveStreamId stream_id = CollectiveStreamId(0),
8347
AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
8448
std::vector<std::vector<GlobalDeviceId>> participant_groups = {});
8549

86-
NcclStreamId stream_id() const;
50+
CollectiveStreamId stream_id() const;
8751

8852
// Returns true if this clique is a subset of `other`: both cliques have the
8953
// same `stream_id` and all clique devices are part of `other` clique.
@@ -103,7 +67,7 @@ class NcclCliqueKey : public CliqueKey {
10367
private:
10468
void HashValue(absl::HashState state) const final;
10569

106-
NcclStreamId stream_id_;
70+
CollectiveStreamId stream_id_;
10771
AsyncStreamKind stream_kind_;
10872

10973
// The full list of groups across all devices which this clique is a part of.

0 commit comments

Comments
 (0)