@@ -16,15 +16,21 @@ limitations under the License.
1616#ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
1717#define XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
1818
19+ #include < functional>
1920#include < map>
2021#include < optional>
2122
23+ #include " absl/status/statusor.h"
24+ #include " xla/core/collectives/clique_id.h"
25+ #include " xla/core/collectives/clique_key.h"
2226#include " xla/executable_run_options.h"
2327#include " xla/service/global_device_id.h"
24- #include " xla/service/gpu/runtime/nccl_clique_key.h"
2528
26- namespace xla {
27- namespace gpu {
29+ namespace xla ::gpu {
30+
31+ // A callback to get a unique clique id.
32+ using CliqueIdCallback = // NOLINT
33+ std::function<absl::StatusOr<CliqueId>(const CliqueKey&)>;
2834
2935// GPU-specific executable options.
3036// We keep these separate from ExecutableRunOptions to avoid adding
@@ -40,11 +46,10 @@ class GpuExecutableRunOptions {
4046 const std::optional<std::map<int , GlobalDeviceId>>& gpu_global_device_ids ()
4147 const ;
4248
43- // Callback that returns a ncclUniqueId encoded as a string for a group of
44- // communicating GPU devices. Used only on NVidia GPUs.
45- GpuExecutableRunOptions& set_nccl_clique_id_callback (
46- NcclCliqueIdCallback nccl_clique_id_callback);
47- const NcclCliqueIdCallback& nccl_clique_id_callback () const ;
49+ // Callback that returns a unique clieque id for a given clique key.
50+ GpuExecutableRunOptions& set_clique_id_callback (
51+ CliqueIdCallback clique_id_callback);
52+ const CliqueIdCallback& clique_id_callback () const ;
4853
4954 // Whether the run requires an exclusive lock on the GPU.
5055 bool requires_exclusive_lock_on_gpu () const {
@@ -57,24 +62,21 @@ class GpuExecutableRunOptions {
5762 return *this ;
5863 }
5964
60- bool enable_mock_nccl_collectives () const {
61- return enable_mock_nccl_collectives_;
62- }
65+ bool enable_mock_collectives () const { return enable_mock_collectives_; }
6366
6467 // Enables mocking nccl collective operations on the GPU.
65- GpuExecutableRunOptions& set_enable_mock_nccl_collectives () {
66- enable_mock_nccl_collectives_ = true ;
68+ GpuExecutableRunOptions& set_enable_mock_collectives () {
69+ enable_mock_collectives_ = true ;
6770 return *this ;
6871 }
6972
7073 private:
7174 bool requires_exclusive_lock_on_gpu_ = false ;
72- bool enable_mock_nccl_collectives_ = false ;
75+ bool enable_mock_collectives_ = false ;
7376 std::optional<std::map<int , GlobalDeviceId>> gpu_global_device_ids_;
74- NcclCliqueIdCallback nccl_clique_id_callback_ ;
77+ CliqueIdCallback clique_id_callback_ ;
7578};
7679
77- } // namespace gpu
78- } // namespace xla
80+ } // namespace xla::gpu
7981
8082#endif // XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
0 commit comments