Skip to content

Commit 50098e0

Browse files
ezhulenev authored and Google-ML-Automation committed
[xla:collectives] NFC: Move CliqueIdCallback alias to gpu_executable_run_options
PiperOrigin-RevId: 702238215
1 parent 42a164f commit 50098e0

File tree

15 files changed

+83
-70
lines changed

15 files changed

+83
-70
lines changed

xla/core/collectives/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
load("//xla/tsl:tsl.bzl", "internal_visibility")
12
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
23

34
package(
45
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
5-
default_visibility = [":friends"],
6+
default_visibility = internal_visibility([":friends"]),
67
licenses = ["notice"],
78
)
89

xla/pjrt/gpu/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ cc_library(
218218
hdrs = ["nccl_id_store.h"],
219219
deps = [
220220
"//xla:status_macros",
221+
"//xla:util",
221222
"//xla/core/collectives:clique_id",
223+
"//xla/core/collectives:clique_key",
222224
"//xla/pjrt/distributed:key_value_store_interface",
223225
"//xla/service:global_device_id",
224226
"//xla/service/gpu/runtime:nccl_api",
@@ -228,6 +230,7 @@ cc_library(
228230
"@com_google_absl//absl/status:statusor",
229231
"@com_google_absl//absl/synchronization",
230232
"@com_google_absl//absl/time",
233+
"@tsl//tsl/platform:casts",
231234
"@tsl//tsl/platform:errors",
232235
"@tsl//tsl/platform:statusor",
233236
],

xla/pjrt/gpu/nccl_id_store.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,45 @@ limitations under the License.
2222
#include "absl/synchronization/mutex.h"
2323
#include "absl/time/time.h"
2424
#include "xla/core/collectives/clique_id.h"
25+
#include "xla/core/collectives/clique_key.h"
2526
#include "xla/service/gpu/runtime/nccl_api.h"
2627
#include "xla/service/gpu/runtime/nccl_clique_key.h"
2728
#include "xla/status_macros.h"
29+
#include "xla/util.h"
30+
#include "tsl/platform/casts.h"
2831
#include "tsl/platform/errors.h"
2932
#include "tsl/platform/statusor.h"
3033

3134
namespace xla {
3235

33-
absl::StatusOr<CliqueId> NcclIdStore::GetNcclUniqueId(
34-
const gpu::NcclCliqueKey& key) {
36+
absl::StatusOr<CliqueId> NcclIdStore::GetNcclUniqueId(const CliqueKey& key) {
37+
auto* gpu_key = tsl::down_cast<const gpu::NcclCliqueKey*>(&key);
38+
if (gpu_key == nullptr) {
39+
return InvalidArgument("Expected GPU clique key");
40+
}
41+
3542
// The caller must ensure that threads calling this method concurrently have
3643
// unique keys, otherwise the global key-value store may hold the wrong value.
3744
{
3845
absl::MutexLock lock(&mu_);
39-
auto it = cache_.find(key);
46+
auto it = cache_.find(*gpu_key);
4047
if (it != cache_.end()) {
4148
return it->second;
4249
}
4350
}
4451
CliqueId clique_id;
45-
int primary_node_id = device_to_node_.at(key.devices()[0]);
52+
int primary_node_id = device_to_node_.at(gpu_key->devices()[0]);
4653
if (node_id_ == primary_node_id) {
4754
TF_ASSIGN_OR_RETURN(clique_id, gpu::NcclApi::Default()->GetUniqueId());
48-
TF_RETURN_IF_ERROR(kv_store_->Set(key.ToString(), clique_id.ToString()));
55+
TF_RETURN_IF_ERROR(
56+
kv_store_->Set(gpu_key->ToString(), clique_id.ToString()));
4957
} else {
5058
TF_ASSIGN_OR_RETURN(std::string id_str,
51-
kv_store_->Get(key.ToString(), absl::Minutes(10)));
59+
kv_store_->Get(gpu_key->ToString(), absl::Minutes(10)));
5260
clique_id = CliqueId(id_str);
5361
}
5462
absl::MutexLock lock(&mu_);
55-
auto result = cache_.emplace(key, std::move(clique_id));
63+
auto result = cache_.emplace(*gpu_key, std::move(clique_id));
5664
TF_RET_CHECK(result.second) << "Unique ID already in cache.";
5765
return result.first->second;
5866
}

xla/pjrt/gpu/nccl_id_store.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "absl/status/statusor.h"
2525
#include "absl/synchronization/mutex.h"
2626
#include "xla/core/collectives/clique_id.h"
27+
#include "xla/core/collectives/clique_key.h"
2728
#include "xla/pjrt/distributed/key_value_store_interface.h"
2829
#include "xla/service/global_device_id.h"
2930
#include "xla/service/gpu/runtime/nccl_clique_key.h"
@@ -42,7 +43,7 @@ class NcclIdStore {
4243
device_to_node_(std::move(device_to_node)),
4344
kv_store_(std::move(kv_store)) {}
4445

45-
absl::StatusOr<CliqueId> GetNcclUniqueId(const gpu::NcclCliqueKey& key);
46+
absl::StatusOr<CliqueId> GetNcclUniqueId(const CliqueKey& key);
4647

4748
private:
4849
const int node_id_;

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,8 +1175,8 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
11751175
if (num_nodes > 1) {
11761176
auto nccl_id_store = std::make_shared<NcclIdStore>(node_id, device_to_node,
11771177
std::move(kv_store));
1178-
gpu_executable_run_options->set_nccl_clique_id_callback(
1179-
[nccl_id_store](const gpu::NcclCliqueKey& key) {
1178+
gpu_executable_run_options->set_clique_id_callback(
1179+
[nccl_id_store](const CliqueKey& key) {
11801180
return nccl_id_store->GetNcclUniqueId(key);
11811181
});
11821182
}
@@ -1300,7 +1300,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
13001300

13011301
auto gpu_run_options = std::make_unique<gpu::GpuExecutableRunOptions>();
13021302
if (options.enable_mock_nccl) {
1303-
gpu_run_options->set_enable_mock_nccl_collectives();
1303+
gpu_run_options->set_enable_mock_collectives();
13041304
}
13051305

13061306
static const bool xla_gpu_require_exclusive_lock =

xla/service/gpu/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,12 @@ cc_library(
9797
visibility = ["//visibility:public"],
9898
deps = [
9999
"//xla:executable_run_options",
100+
"//xla/core/collectives:clique_id",
101+
"//xla/core/collectives:clique_key",
100102
"//xla/service:global_device_id",
101103
"//xla/service/gpu/runtime:nccl_clique_key",
104+
"@com_google_absl//absl/status",
105+
"@com_google_absl//absl/status:statusor",
102106
],
103107
)
104108

xla/service/gpu/gpu_executable.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ class ResourceRequests : public Thunk::ResourceRequests {
259259

260260
bool is_local = r.key.devices().size() == r.num_local_participants;
261261
TF_ASSIGN_OR_RETURN(
262-
const NcclCliqueIdCallback* clique_id_callback,
263-
GetNcclCliqueIdCallback(params.nccl_clique_id_callback, is_local));
262+
const CliqueIdCallback* clique_id_callback,
263+
GetCliqueIdCallback(params.nccl_clique_id_callback, is_local));
264264

265265
int64_t max_channels = r.key.stream_kind() == AsyncStreamKind::kCollective
266266
? params.collective_max_nchannels
@@ -348,7 +348,7 @@ absl::Status ExecuteThunks(
348348
run_options->run_options().gpu_executable_run_options()
349349
? run_options->run_options()
350350
.gpu_executable_run_options()
351-
->enable_mock_nccl_collectives()
351+
->enable_mock_collectives()
352352
: false;
353353

354354
int64_t collective_max_nchannels =

xla/service/gpu/gpu_executable_run_options.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@ limitations under the License.
2121

2222
#include "xla/executable_run_options.h"
2323
#include "xla/service/global_device_id.h"
24-
#include "xla/service/gpu/runtime/nccl_clique_key.h"
2524

26-
namespace xla {
27-
namespace gpu {
25+
namespace xla::gpu {
2826

2927
GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
3028
std::optional<std::map<int, GlobalDeviceId>> gpu_global_device_ids) {
@@ -37,16 +35,14 @@ GpuExecutableRunOptions::gpu_global_device_ids() const {
3735
return gpu_global_device_ids_;
3836
}
3937

40-
GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_clique_id_callback(
41-
NcclCliqueIdCallback nccl_clique_id_callback) {
42-
nccl_clique_id_callback_ = std::move(nccl_clique_id_callback);
38+
GpuExecutableRunOptions& GpuExecutableRunOptions::set_clique_id_callback(
39+
CliqueIdCallback clique_id_callback) {
40+
clique_id_callback_ = std::move(clique_id_callback);
4341
return *this;
4442
}
4543

46-
const NcclCliqueIdCallback& GpuExecutableRunOptions::nccl_clique_id_callback()
47-
const {
48-
return nccl_clique_id_callback_;
44+
const CliqueIdCallback& GpuExecutableRunOptions::clique_id_callback() const {
45+
return clique_id_callback_;
4946
}
5047

51-
} // namespace gpu
52-
} // namespace xla
48+
} // namespace xla::gpu

xla/service/gpu/gpu_executable_run_options.h

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,21 @@ limitations under the License.
1616
#ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
1717
#define XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
1818

19+
#include <functional>
1920
#include <map>
2021
#include <optional>
2122

23+
#include "absl/status/statusor.h"
24+
#include "xla/core/collectives/clique_id.h"
25+
#include "xla/core/collectives/clique_key.h"
2226
#include "xla/executable_run_options.h"
2327
#include "xla/service/global_device_id.h"
24-
#include "xla/service/gpu/runtime/nccl_clique_key.h"
2528

26-
namespace xla {
27-
namespace gpu {
29+
namespace xla::gpu {
30+
31+
// A callback to get a unique clique id.
32+
using CliqueIdCallback = // NOLINT
33+
std::function<absl::StatusOr<CliqueId>(const CliqueKey&)>;
2834

2935
// GPU-specific executable options.
3036
// We keep these separate from ExecutableRunOptions to avoid adding
@@ -40,11 +46,10 @@ class GpuExecutableRunOptions {
4046
const std::optional<std::map<int, GlobalDeviceId>>& gpu_global_device_ids()
4147
const;
4248

43-
// Callback that returns a ncclUniqueId encoded as a string for a group of
44-
// communicating GPU devices. Used only on NVidia GPUs.
45-
GpuExecutableRunOptions& set_nccl_clique_id_callback(
46-
NcclCliqueIdCallback nccl_clique_id_callback);
47-
const NcclCliqueIdCallback& nccl_clique_id_callback() const;
49+
// Callback that returns a unique clique id for a given clique key.
50+
GpuExecutableRunOptions& set_clique_id_callback(
51+
CliqueIdCallback clique_id_callback);
52+
const CliqueIdCallback& clique_id_callback() const;
4853

4954
// Whether the run requires an exclusive lock on the GPU.
5055
bool requires_exclusive_lock_on_gpu() const {
@@ -57,24 +62,21 @@ class GpuExecutableRunOptions {
5762
return *this;
5863
}
5964

60-
bool enable_mock_nccl_collectives() const {
61-
return enable_mock_nccl_collectives_;
62-
}
65+
bool enable_mock_collectives() const { return enable_mock_collectives_; }
6366

6467
// Enables mocking nccl collective operations on the GPU.
65-
GpuExecutableRunOptions& set_enable_mock_nccl_collectives() {
66-
enable_mock_nccl_collectives_ = true;
68+
GpuExecutableRunOptions& set_enable_mock_collectives() {
69+
enable_mock_collectives_ = true;
6770
return *this;
6871
}
6972

7073
private:
7174
bool requires_exclusive_lock_on_gpu_ = false;
72-
bool enable_mock_nccl_collectives_ = false;
75+
bool enable_mock_collectives_ = false;
7376
std::optional<std::map<int, GlobalDeviceId>> gpu_global_device_ids_;
74-
NcclCliqueIdCallback nccl_clique_id_callback_;
77+
CliqueIdCallback clique_id_callback_;
7578
};
7679

77-
} // namespace gpu
78-
} // namespace xla
80+
} // namespace xla::gpu
7981

8082
#endif // XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_

xla/service/gpu/runtime/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,13 @@ cc_library(
289289
"//xla:types",
290290
"//xla:util",
291291
"//xla/core/collectives:clique_id",
292+
"//xla/core/collectives:clique_key",
292293
"//xla/core/collectives:communicator",
293294
"//xla/core/collectives:rank_id",
294295
"//xla/service:global_device_id",
295296
"//xla/service:lockable",
296297
"//xla/service:rendezvous",
298+
"//xla/service/gpu:gpu_executable_run_options",
297299
"//xla/stream_executor:stream_executor_h",
298300
"@com_google_absl//absl/algorithm:container",
299301
"@com_google_absl//absl/base:core_headers",

0 commit comments

Comments
 (0)