Skip to content

Commit f0ca2e2

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla:collectives] NFC: Move NcclCliqueKey to GpuCliqueKey
PiperOrigin-RevId: 702261245
1 parent 22d4102 commit f0ca2e2

File tree

8 files changed

+205
-191
lines changed

8 files changed

+205
-191
lines changed

xla/backends/gpu/collectives/BUILD

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
2+
load("//xla:xla.bzl", "xla_cc_test")
23
load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
34
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
45
load("//xla/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
@@ -16,13 +17,46 @@ package_group(
1617
],
1718
)
1819

20+
cc_library(
21+
name = "gpu_clique_key",
22+
srcs = ["gpu_clique_key.cc"],
23+
hdrs = ["gpu_clique_key.h"],
24+
deps = [
25+
"//xla/core/collectives",
26+
"//xla/core/collectives:clique_id",
27+
"//xla/core/collectives:clique_key",
28+
"//xla/service:global_device_id",
29+
"//xla/tsl/lib/gtl:int_type",
30+
"@com_google_absl//absl/algorithm:container",
31+
"@com_google_absl//absl/hash",
32+
"@com_google_absl//absl/strings",
33+
"@com_google_absl//absl/strings:str_format",
34+
"@com_google_absl//absl/types:span",
35+
"@tsl//tsl/platform:casts",
36+
"@tsl//tsl/platform:logging",
37+
],
38+
)
39+
40+
xla_cc_test(
41+
name = "gpu_clique_key_test",
42+
srcs = ["gpu_clique_key_test.cc"],
43+
deps = [
44+
":gpu_clique_key",
45+
"//xla/core/collectives:clique_id",
46+
"//xla/service:global_device_id",
47+
"@com_google_absl//absl/container:btree",
48+
"@com_google_absl//absl/status",
49+
"@tsl//tsl/platform:status_matchers",
50+
"@tsl//tsl/platform:test",
51+
"@tsl//tsl/platform:test_main",
52+
],
53+
)
54+
1955
cc_library(
2056
name = "gpu_collectives",
21-
srcs = ["gpu_collectives.cc"],
2257
hdrs = ["gpu_collectives.h"],
2358
deps = [
2459
"//xla/core/collectives",
25-
"//xla/tsl/lib/gtl:int_type",
2660
],
2761
)
2862

xla/service/gpu/runtime/nccl_clique_key.cc renamed to xla/backends/gpu/collectives/gpu_clique_key.cc

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "xla/service/gpu/runtime/nccl_clique_key.h"
16+
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
1717

18+
#include <cstdint>
1819
#include <string>
1920
#include <utility>
2021
#include <vector>
@@ -24,19 +25,23 @@ limitations under the License.
2425
#include "absl/strings/str_format.h"
2526
#include "absl/strings/str_join.h"
2627
#include "absl/types/span.h"
27-
#include "xla/backends/gpu/collectives/gpu_collectives.h"
2828
#include "xla/core/collectives/clique_key.h"
2929
#include "xla/service/global_device_id.h"
3030
#include "tsl/platform/casts.h"
3131
#include "tsl/platform/logging.h"
3232

3333
namespace xla::gpu {
3434

35-
//===----------------------------------------------------------------------===//
36-
// NcclCliqueKey
37-
//===----------------------------------------------------------------------===//
35+
CollectiveStreamId GetCollectiveStreamId(bool is_async,
36+
AsyncStreamKind stream_kind) {
37+
// TODO(ezhulenev): This implementation does not look correct as stream IDs
38+
// are not really unique. Figure out if it's the case and fix either the code
39+
// or the documentation.
40+
int64_t stream_id = static_cast<int64_t>(stream_kind);
41+
return CollectiveStreamId(is_async ? stream_id + 1 : 0);
42+
}
3843

39-
NcclCliqueKey::NcclCliqueKey(
44+
GpuCliqueKey::GpuCliqueKey(
4045
std::vector<GlobalDeviceId> devices, CollectiveStreamId stream_id,
4146
AsyncStreamKind stream_kind,
4247
std::vector<std::vector<GlobalDeviceId>> participant_groups)
@@ -57,10 +62,10 @@ NcclCliqueKey::NcclCliqueKey(
5762
absl::c_sort(participant_groups_, compare_groups);
5863
}
5964

60-
CollectiveStreamId NcclCliqueKey::stream_id() const { return stream_id_; }
65+
CollectiveStreamId GpuCliqueKey::stream_id() const { return stream_id_; }
6166

62-
bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const {
63-
auto* other_nccl = tsl::down_cast<const NcclCliqueKey*>(&other);
67+
bool GpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
68+
auto* other_nccl = tsl::down_cast<const GpuCliqueKey*>(&other);
6469
if (other_nccl == nullptr) return false;
6570

6671
return stream_id_ == other_nccl->stream_id_ &&
@@ -69,7 +74,7 @@ bool NcclCliqueKey::IsSubsetOf(const CliqueKey& other) const {
6974
});
7075
}
7176

72-
std::string NcclCliqueKey::ToString() const {
77+
std::string GpuCliqueKey::ToString() const {
7378
std::string group_string = "";
7479
if (!participant_groups_.empty()) {
7580
std::vector<std::string> values;
@@ -84,17 +89,17 @@ std::string NcclCliqueKey::ToString() const {
8489
group_string);
8590
}
8691

87-
void NcclCliqueKey::HashValue(absl::HashState state) const {
92+
void GpuCliqueKey::HashValue(absl::HashState state) const {
8893
absl::HashState::combine(std::move(state), devices(), stream_id_,
8994
participant_groups_);
9095
}
9196

92-
bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
97+
bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b) {
9398
return a.devices() == b.devices() && a.stream_id_ == b.stream_id_ &&
9499
a.participant_groups_ == b.participant_groups_;
95100
}
96101

97-
bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) {
102+
bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b) {
98103
if (a.devices().size() < b.devices().size()) return true;
99104
if (b.devices().size() < a.devices().size()) return false;
100105

@@ -104,7 +109,7 @@ bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) {
104109
return a.stream_id_.value() < b.stream_id_.value();
105110
}
106111

107-
bool operator>(const NcclCliqueKey& a, const NcclCliqueKey& b) {
112+
bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b) {
108113
if (a.devices().size() > b.devices().size()) return true;
109114
if (b.devices().size() > a.devices().size()) return false;
110115

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_
17+
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_
18+
19+
#include <cstdint>
20+
#include <string>
21+
#include <vector>
22+
23+
#include "absl/hash/hash.h"
24+
#include "xla/core/collectives/clique_key.h"
25+
#include "xla/service/global_device_id.h"
26+
#include "xla/tsl/lib/gtl/int_type.h"
27+
28+
namespace xla::gpu {
29+
30+
// In XLA:GPU we use different streams for different kinds of collective
31+
// operations, and include the async stream kind into the GPU clique key.
32+
//
33+
// We carefully isolate different kinds of collectives using separate
34+
// communicators and guarantee that all collective operations have a total order
35+
// that will not create a deadlock.
36+
enum class AsyncStreamKind : int64_t {
37+
kCollective = 0, // Stream for asynchronous collective ops.
38+
kP2P0 = 1, // One Stream for P2P Send and Recv ops.
39+
kP2P1 = 2, // Another Stream for P2P Send and Recv ops.
40+
kMemCpyP2P = 3, // Stream for MemCpyP2P
41+
};
42+
43+
inline constexpr int64_t kAsyncStreamTotal =
44+
static_cast<int64_t>(AsyncStreamKind::kMemCpyP2P) + 1;
45+
46+
// Strongly-typed wrapper to represent collective stream ID.
47+
TSL_LIB_GTL_DEFINE_INT_TYPE(CollectiveStreamId, uint64_t);
48+
49+
// Assigns a unique ID to a stream for asynchronous or synchronous execution.
50+
// These IDs can be used, for example, to look up the NCCL communicator.
51+
CollectiveStreamId GetCollectiveStreamId(
52+
bool is_async, AsyncStreamKind stream_kind = AsyncStreamKind::kCollective);
53+
54+
// Clique key for identifying a particular collectives clique on a GPU backend.
55+
class GpuCliqueKey : public CliqueKey {
56+
public:
57+
explicit GpuCliqueKey(
58+
std::vector<GlobalDeviceId> devices,
59+
CollectiveStreamId stream_id = CollectiveStreamId(0),
60+
AsyncStreamKind stream_kind = AsyncStreamKind::kCollective,
61+
std::vector<std::vector<GlobalDeviceId>> participant_groups = {});
62+
63+
CollectiveStreamId stream_id() const;
64+
65+
// Returns true if this clique is a subset of `other`: both cliques have the
66+
// same `stream_id` and all clique devices are part of `other` clique.
67+
bool IsSubsetOf(const CliqueKey& other) const final;
68+
69+
// Returns the stream kind for this clique key, stream kind will be used to
70+
// specify what configuration to pass for each type of operation.
71+
AsyncStreamKind stream_kind() const { return stream_kind_; }
72+
73+
std::string ToString() const final;
74+
75+
// GPU clique keys have a total order on which we rely on for acquiring
76+
// cliques in the same order across all participating devices.
77+
friend bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b);
78+
friend bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b);
79+
friend bool operator>(const GpuCliqueKey& a, const GpuCliqueKey& b);
80+
81+
private:
82+
void HashValue(absl::HashState state) const final;
83+
84+
CollectiveStreamId stream_id_;
85+
AsyncStreamKind stream_kind_;
86+
87+
// The full list of groups across all devices which this clique is a part of.
88+
//
89+
// When GPU communicator splitting is enabled, this is used to distinguish
90+
// which cliques can be reused from the cache or must be split in order to
91+
// prevent a deadlock situation.
92+
//
93+
// For example, imagine we have a communicator with devices = [0,1] and
94+
// groups = [0, 1] Later on, we may want to create communicators [0, 1] and
95+
// [2, 3] by splitting [0, 1, 2, 3] If ranks 0 and 1 reuse the existing
96+
// [0, 1] clique but ranks 2 and 3 initiate a split, there will be a deadlock
97+
// since ranks 2, 3 and will be waiting forever for 0, 1 to join the split.
98+
//
99+
// Having the participating groups as part of the cache key will prevent such
100+
// situations
101+
std::vector<std::vector<GlobalDeviceId>> participant_groups_;
102+
};
103+
104+
bool operator==(const GpuCliqueKey& a, const GpuCliqueKey& b);
105+
bool operator<(const GpuCliqueKey& a, const GpuCliqueKey& b);
106+
107+
} // namespace xla::gpu
108+
109+
#endif // XLA_BACKENDS_GPU_COLLECTIVES_GPU_CLIQUE_KEY_H_

0 commit comments

Comments
 (0)