
Commit 09bc536

Title: New NCCL Collectives Latency Estimator
Description: This PR introduces a new analytical latency estimator for NCCL collectives, enabled via the following flags:

--xla_gpu_enable_analytical_sol_latency_estimator \
--xla_gpu_analytical_latency_estimator_options='nccl_op_launch_us=<value>,nic_speed_gbps=<value>,chunk_prep_us=<value>,rtt_us=<value>,gpus_per_node=<value>,chunk_size_bytes=<value>'

Replace <value> with the appropriate number for your system (e.g., nccl_op_launch_us=XX). This estimator should improve accuracy and performance, especially for large-scale distributed training.

PiperOrigin-RevId: 707261072
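For illustration, a complete invocation via the XLA_FLAGS environment variable might look like the following (the training command is hypothetical, and the option values shown are just the built-in defaults introduced by this PR, so in practice you would pass only the ones you want to override):

    XLA_FLAGS="--xla_gpu_enable_analytical_sol_latency_estimator --xla_gpu_analytical_latency_estimator_options='nccl_op_launch_us=45,nic_speed_gbps=25,chunk_prep_us=6,rtt_us=31,gpus_per_node=8,chunk_size_bytes=4194304'" python train.py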
1 parent 2445c22 commit 09bc536

File tree

11 files changed (+748, -1 lines)


xla/debug_options_flags.cc

Lines changed: 50 additions & 0 deletions
@@ -31,6 +31,7 @@ limitations under the License.
 #include "absl/log/log.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
@@ -169,6 +170,25 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_dump_latency_hiding_schedule(false);
   opts.set_xla_gpu_enable_latency_hiding_scheduler(false);
   opts.set_xla_gpu_enable_analytical_latency_estimator(false);
+  opts.set_xla_gpu_enable_analytical_sol_latency_estimator(false);
+  auto* sol_estimator_defaults =
+      opts.mutable_xla_gpu_analytical_latency_estimator_options();
+  sol_estimator_defaults->emplace(
+      "nccl_op_launch_us",
+      absl::StrCat(static_cast<int>(100.0f * kDefaultNcclCostModelCoeff)));
+  sol_estimator_defaults->emplace(
+      "nic_speed_gbps",
+      absl::StrCat(static_cast<int>(55.56f * kDefaultNcclCostModelCoeff)));
+  sol_estimator_defaults->emplace(
+      "chunk_prep_us",
+      absl::StrCat(static_cast<int>(13.34f * kDefaultNcclCostModelCoeff)));
+  sol_estimator_defaults->emplace(
+      "rtt_us",
+      absl::StrCat(static_cast<int>(68.89f * kDefaultNcclCostModelCoeff)));
+  sol_estimator_defaults->emplace(
+      "chunk_size_bytes", absl::StrCat(kDefaultNcclCostModelChunkSizeBytes));
+  sol_estimator_defaults->emplace(
+      "gpus_per_node", absl::StrCat(kDefaultNcclCostModelGPUsPerNode));
   opts.set_xla_gpu_pgle_profile_file_or_directory_path("");
   opts.set_xla_gpu_memory_limit_slop_factor(95);
   opts.set_xla_gpu_enable_highest_priority_async_stream(true);
@@ -470,6 +490,17 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
     return true;
   };

+  // Custom "sub-parser" lambda for
+  // xla_gpu_analytical_latency_estimator_options.
+  auto setter_for_xla_gpu_analytical_latency_estimator_options =
+      [debug_options](std::string comma_separated_values) {
+        google::protobuf::Map<std::string, std::string>* options_map =
+            debug_options
+                ->mutable_xla_gpu_analytical_latency_estimator_options();
+        parse_xla_backend_extra_options(options_map, comma_separated_values);
+        return true;
+      };
+
   // Custom "sub-parser" lambda for xla_partitioning_algorithm.
   auto setter_for_xla_partitioning_algorithm =
       [debug_options](const std::string& value) {
@@ -1568,6 +1599,25 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       debug_options->xla_gpu_enable_analytical_latency_estimator(),
       "Enable analytical latency estimator for latency-hiding scheduler for "
       "XLA:GPU"));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_enable_analytical_sol_latency_estimator",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_enable_analytical_sol_latency_estimator),
+      debug_options->xla_gpu_enable_analytical_sol_latency_estimator(),
+      "Enable analytical Speed-of-Light latency estimator for latency-hiding "
+      "scheduler for XLA:GPU; must be used without "
+      "xla_gpu_enable_analytical_latency_estimator. It can also benefit from "
+      "user-passed options in xla_gpu_analytical_latency_estimator_options"));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_analytical_latency_estimator_options",
+      setter_for_xla_gpu_analytical_latency_estimator_options, "",
+      "Extra platform-specific options to improve analytical latency "
+      "estimator precision; comma-separated list of 'key=val' "
+      "strings (=val may be omitted); no whitespace around commas. "
+      "Available options: "
+      "--xla_gpu_analytical_latency_estimator_options='nccl_op_launch_us=55,"
+      "nic_speed_gbps=40,chunk_prep_us=1,rtt_us=2,gpus_per_node=4,"
+      "chunk_size_bytes=1024'"));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_pgle_profile_file_or_directory_path",
       string_setter_for(
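The setter above reuses the existing parse_xla_backend_extra_options helper. As a rough standalone sketch of the 'key=val' splitting the help text describes (my approximation for illustration, not XLA's actual helper):

#include <iostream>
#include <map>
#include <sstream>
#include <string>

// Splits "k1=v1,k2=v2,..." into a map. "=val" may be omitted, yielding an
// empty value; no whitespace handling, mirroring the flag's documented
// constraints. Illustrative only.
std::map<std::string, std::string> ParseKeyValueList(
    const std::string& comma_separated_values) {
  std::map<std::string, std::string> options;
  std::stringstream ss(comma_separated_values);
  std::string token;
  while (std::getline(ss, token, ',')) {
    const size_t eq = token.find('=');
    if (eq == std::string::npos) {
      options[token] = "";  // "=val" omitted
    } else {
      options[token.substr(0, eq)] = token.substr(eq + 1);
    }
  }
  return options;
}

int main() {
  for (const auto& [key, value] :
       ParseKeyValueList("nccl_op_launch_us=45,nic_speed_gbps=25")) {
    std::cout << key << " -> " << value << "\n";
  }
}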

xla/service/collective_utils.h

Lines changed: 5 additions & 0 deletions
@@ -32,6 +32,11 @@ constexpr int64_t kDefaultAllGatherCombineThreshold = 30 * 1024 * 1024 + 7;
 // pass will combine collectives.
 constexpr int64_t kDefaultReduceScatterCombineThreshold = 30 * 1024 * 1024 + 7;

+// Defines the default coefficient for the SoL NCCL collective cost model.
+// Note: XLA flags allow a user to override the default values of the model.
+constexpr float kDefaultNcclCostModelCoeff = 0.45f;
+constexpr int64_t kDefaultNcclCostModelChunkSizeBytes = 4194304;  // 4MB
+constexpr int64_t kDefaultNcclCostModelGPUsPerNode = 8;
 }  // namespace xla

 #endif  // XLA_SERVICE_COLLECTIVE_UTILS_H_
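For reference, the defaults registered in debug_options_flags.cc above are the base coefficients (100.0f, 55.56f, 13.34f, 68.89f) scaled by kDefaultNcclCostModelCoeff and truncated to int, while the chunk-size and GPUs-per-node defaults are used unscaled. A minimal C++ sketch reproducing that arithmetic:

#include <iostream>

// Reproduces the default-value arithmetic from debug_options_flags.cc:
// each base coefficient is scaled by kDefaultNcclCostModelCoeff (0.45f)
// and truncated via static_cast<int>.
int main() {
  constexpr float kCoeff = 0.45f;  // kDefaultNcclCostModelCoeff
  std::cout << "nccl_op_launch_us = " << static_cast<int>(100.0f * kCoeff)   // 45
            << "\nnic_speed_gbps    = " << static_cast<int>(55.56f * kCoeff)  // 25
            << "\nchunk_prep_us     = " << static_cast<int>(13.34f * kCoeff)  // 6
            << "\nrtt_us            = " << static_cast<int>(68.89f * kCoeff)  // 31
            << "\nchunk_size_bytes  = " << 4194304  // kDefaultNcclCostModelChunkSizeBytes
            << "\ngpus_per_node     = " << 8        // kDefaultNcclCostModelGPUsPerNode
            << "\n";
}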

xla/service/gpu/BUILD

Lines changed: 1 addition & 0 deletions
@@ -2120,6 +2120,7 @@ cc_library(
         "//xla/service:p2p_schedule_preparation",
         "//xla/service:profile_guided_latency_estimator",
         "//xla/service/gpu/model:analytical_latency_estimator",
+        "//xla/service/gpu/model:sol_latency_estimator",
         "//xla/service/gpu/transforms:pgle_accuracy_checker",
         "//xla/service/gpu/transforms:schedule_postprocessing",
         "//xla/service/gpu/transforms:scheduling_instruction_annotator",

xla/service/gpu/gpu_hlo_schedule.cc

Lines changed: 11 additions & 0 deletions
@@ -51,6 +51,7 @@ limitations under the License.
 #include "xla/service/gpu/flag_utils.h"
 #include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
 #include "xla/service/gpu/model/analytical_latency_estimator.h"
+#include "xla/service/gpu/model/sol_latency_estimator.h"
 #include "xla/service/gpu/transforms/pgle_accuracy_checker.h"
 #include "xla/service/gpu/transforms/schedule_postprocessing.h"
 #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h"
@@ -496,6 +497,16 @@ std::unique_ptr<LatencyEstimator> GetLatencyEstimator(
         },
         module.entry_computation());
   }
+
+  if (options.xla_gpu_enable_analytical_sol_latency_estimator()) {
+    LOG(INFO) << "Using Speed-of-Light (SoL) analytical latency estimator";
+    return std::make_unique<SolLatencyEstimator>(
+        config, std::move(gpu_latency_estimator), gpu_device_info,
+        [input_pointer_size = pointer_size](const Shape& shape) {
+          return GetSizeOfShape(shape, input_pointer_size);
+        },
+        module.entry_computation());
+  }
   return gpu_latency_estimator;
 }
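The SoL branch sits after the existing analytical-estimator branch, and each branch returns immediately, so if both flags are enabled the older estimator wins and the SoL one is never reached; this is why the flag help text says it "must be used without xla_gpu_enable_analytical_latency_estimator". A condensed sketch of that dispatch order (names simplified, not the real signatures):

// Condensed model of the selection logic in GetLatencyEstimator: the first
// matching branch returns, so the SoL estimator is only used when the older
// analytical estimator is disabled.
enum class Estimator { kAnalytical, kSol, kDefault };

Estimator PickEstimator(bool analytical_enabled, bool sol_enabled) {
  if (analytical_enabled) return Estimator::kAnalytical;
  if (sol_enabled) return Estimator::kSol;
  return Estimator::kDefault;  // falls through to gpu_latency_estimator
}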

xla/service/gpu/model/BUILD

Lines changed: 71 additions & 0 deletions
@@ -43,6 +43,77 @@ cc_library(
     ],
 )

+cc_library(
+    name = "sol_latency_estimator",
+    srcs = ["sol_latency_estimator.cc"],
+    hdrs = ["sol_latency_estimator.h"],
+    deps = [
+        ":coalescing_analysis",
+        ":fusion_analysis_cache",
+        ":gpu_hlo_cost_analysis",
+        ":gpu_performance_model",
+        ":gpu_performance_model_base",
+        ":hlo_op_profiles",
+        ":sol_gpu_cost_model",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/analysis:hlo_dataflow_analysis",
+        "//xla/hlo/analysis:indexing_analysis",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/hlo/utils:hlo_traversal",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service:latency_hiding_scheduler",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:gpu_fusible",
+        "//xla/service/gpu:hlo_fusion_analysis",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu/fusions",
+        "//xla/service/gpu/fusions:fusion_emitter",
+        "//xla/stream_executor:device_description",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@tsl//tsl/platform:errors",
+        "@tsl//tsl/platform:status",
+    ],
+)
+
+cc_library(
+    name = "sol_gpu_cost_model",
+    srcs = ["sol_gpu_cost_model.cc"],
+    hdrs = ["sol_gpu_cost_model.h"],
+    deps = [
+        "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/numeric:bits",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+    ],
+)
+
+xla_cc_test(
+    name = "sol_gpu_cost_model_test",
+    srcs = ["sol_gpu_cost_model_test.cc"],
+    deps = [
+        ":sol_gpu_cost_model",
+        "//xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/time",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 xla_test(
     name = "analytical_latency_estimator_test",
     srcs = ["analytical_latency_estimator_test.cc"],
xla/service/gpu/model/sol_gpu_cost_model.cc

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/model/sol_gpu_cost_model.h"

#include <cmath>
#include <cstdint>
#include <string>

#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/numeric/bits.h"
#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/hlo/ir/hlo_module.h"

namespace xla {
namespace gpu {
namespace {
// Constants for NCCL SoL model
constexpr double kHeaderOverhead = 0.025;
constexpr absl::string_view kNcclOpLaunchUs = "nccl_op_launch_us";
constexpr absl::string_view kNicSpeedGbps = "nic_speed_gbps";
constexpr absl::string_view kChunkPrepUs = "chunk_prep_us";
constexpr absl::string_view kRttUs = "rtt_us";
constexpr absl::string_view kGpusPerNode = "gpus_per_node";
constexpr absl::string_view kChunkSizeBytes = "chunk_size_bytes";

// Returns the number of communicators in the mask.
// For example, if the mask is 0x0, this function returns 1. If the mask is 0x7,
// this function returns 8.
int NumCommunicators(const absl::string_view mask) {
  // Assuming the mask is a hexadecimal number
  uint64_t mask_value = std::stoul(std::string(mask), nullptr, 16);
  int bit_count = absl::popcount(mask_value);  // Count set bits
  return static_cast<int>(std::pow(2, bit_count));
}

// Returns the number of rounds for the given collective type.
int NumRounds(const SolGPUCostModel::CollectiveType& coll_type) {
  // AllReduce requires ReduceScatter and AllGather, so it has 2 rounds.
  return coll_type == SolGPUCostModel::CollectiveType::kAllReduce ? 2 : 1;
}

}  // namespace

SolGPUCostModel::Config GetConfig(const HloModule* module) {
  SolGPUCostModel::Config config;
  const auto& extra_options =
      module->config()
          .debug_options()
          .xla_gpu_analytical_latency_estimator_options();
  for (const auto& [option_name, option_value] : extra_options) {
    int64_t value;
    double value_d;
    VLOG(2) << "[SoL] option: " << option_name << " is " << option_value;
    if (option_name == kNcclOpLaunchUs &&
        absl::SimpleAtoi(option_value, &value)) {
      config.nccl_op_launch_time = absl::Microseconds(value);
    } else if (option_name == kNicSpeedGbps &&
               absl::SimpleAtod(option_value, &value_d)) {
      config.nic_speed_gbps = value_d;
    } else if (option_name == kChunkPrepUs &&
               absl::SimpleAtoi(option_value, &value)) {
      config.chunk_prep_time = absl::Microseconds(value);
    } else if (option_name == kRttUs &&
               absl::SimpleAtoi(option_value, &value)) {
      config.rtt = absl::Microseconds(value);
    } else if (option_name == kGpusPerNode &&
               absl::SimpleAtoi(option_value, &value)) {
      config.gpus_per_node = value;
    } else if (option_name == kChunkSizeBytes &&
               absl::SimpleAtoi(option_value, &value)) {
      config.chunk_size_bytes = value;
    }
  }
  return config;
}

SolGPUCostModel::SolGPUCostModel(const Config& sys_config)
    : xla_flag_config_(sys_config) {
  VLOG(2) << "[SoL] NIC speed: " << xla_flag_config_.nic_speed_gbps;
  VLOG(2) << "[SoL] RTT: " << xla_flag_config_.rtt;
  VLOG(2) << "[SoL] Chunk preparation time: "
          << xla_flag_config_.chunk_prep_time;
  VLOG(2) << "[SoL] NCCL op launch time: "
          << xla_flag_config_.nccl_op_launch_time;
  VLOG(2) << "[SoL] GPUs per node: " << xla_flag_config_.gpus_per_node;
}

// This is an insignificant term, and we are making it consistent
// with the existing formula.
absl::Duration SolGPUCostModel::ChunkPrepLatency(
    const int64_t per_gpu_msg_size_bytes) const {
  return std::ceil(static_cast<double>(per_gpu_msg_size_bytes) /
                   xla_flag_config_.chunk_size_bytes) *
         xla_flag_config_.chunk_prep_time;
}

absl::Duration SolGPUCostModel::TransferDuration(
    const int64_t per_gpu_msg_size_bytes) const {
  // x1e6 to convert secs to microseconds;
  // x1024*1024*1024 to convert Gbytes/sec to bytes/sec
  const long double ret =
      (1e6 * static_cast<long double>(per_gpu_msg_size_bytes)) /
      (std::pow(1024.0, 3) * xla_flag_config_.nic_speed_gbps);
  return absl::Microseconds(ret * (1 + kHeaderOverhead));
}

absl::Duration SolGPUCostModel::RingLatency(
    const int64_t buff_size_bytes, const int num_nodes,
    const CollectiveType& coll_type, const absl::string_view mask) const {
  const int num_gpus = NumGpusPerComm(num_nodes, coll_type, mask);

  int64_t per_gpu_msg_size_bytes;
  if (coll_type == CollectiveType::kSendRecv) {
    per_gpu_msg_size_bytes = buff_size_bytes;
  } else {
    per_gpu_msg_size_bytes = buff_size_bytes / num_gpus;
  }

  // This is the number of GPUs per communicator per node. We assume that each
  // GPU has a NIC, and this is also the number of NICs per communicator per
  // node.
  // Note that this happens to be the correct value (i.e. 1) for SendRecv.
  int num_gpus_per_node = num_gpus / num_nodes;

  // In each channel, consider one GPU next to the Ethernet link. Below is the
  // sum of 3 time costs for each piece of data of size
  // `per_gpu_msg_size_bytes`:
  //
  // 1. transfer duration defined by the NIC bandwidth,
  // 2. chunk preparation latency, and
  // 3. RTT
  //
  // then followed by two factors:
  //
  // 1. Multiply by `num_gpus - 1`, as `num_gpus - 1` pieces of data will be
  //    sent over the link in AllGather.
  // 2. Divide by `num_gpus_per_node` as there are `num_gpus_per_node` NICs
  //    and GPUs in each node for parallelism.
  //
  // Better estimates of terms like this will come in future versions
  // of the SoL model.
  absl::Duration ret = TransferDuration(per_gpu_msg_size_bytes) +
                       ChunkPrepLatency(per_gpu_msg_size_bytes) +
                       xla_flag_config_.rtt;
  ret *= (num_gpus - 1.0) / static_cast<long double>(num_gpus_per_node);
  // Multiply by the number of rounds, which is different for AllReduce.
  ret = ret * NumRounds(coll_type);

  // Time to initiate the collective.
  return ret + xla_flag_config_.nccl_op_launch_time;
}

// Helper functions
int SolGPUCostModel::NumGpusPerComm(int num_nodes,
                                    const CollectiveType& coll_type,
                                    const absl::string_view mask) const {
  if (coll_type == CollectiveType::kSendRecv) {
    return 2;
  }
  int num_comms = NumCommunicators(mask);
  CHECK_EQ(xla_flag_config_.gpus_per_node % num_comms, 0)
      << "GPU_PER_NODE must be divisible by the number of communicators. "
         "GPU_PER_NODE: "
      << xla_flag_config_.gpus_per_node
      << " Number of communicators: " << num_comms
      << ". Adjust the number of GPUs per node with the flag "
         "gpus_per_node in xla_gpu_analytical_latency_estimator_options.";
  return num_nodes * xla_flag_config_.gpus_per_node / num_comms;
}

}  // namespace gpu
}  // namespace xla
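Collapsing the model above into one closed form (my restatement of the committed code, with G the number of GPUs per communicator, g = G / num_nodes GPUs and NICs per node, R the number of rounds, and S the buffer size; for SendRecv the per-GPU size is S itself and G = 2):

T_{ring} = T_{launch} + R \cdot \frac{G-1}{g}\left(T_{xfer}(S/G) + T_{prep}(S/G) + RTT\right),
\quad T_{xfer}(s) = \frac{10^6 (1 + 0.025)\, s}{2^{30} \cdot \mathrm{nic\_speed\_gbps}}\ \mu s,
\quad T_{prep}(s) = \left\lceil \frac{s}{\mathrm{chunk\_size\_bytes}} \right\rceil \cdot \mathrm{chunk\_prep\_us}.

To make this concrete, here is a self-contained sketch mirroring the committed formulas in plain C++ (microsecond doubles instead of absl::Duration; the config values are the scaled flag defaults, and the 1 GiB, 8-node, single-communicator AllGather workload is an assumed example, not from the commit):

#include <cmath>
#include <cstdint>
#include <iostream>

// Standalone sketch of the SoL ring model from sol_gpu_cost_model.cc.
struct SolConfig {
  double nccl_op_launch_us = 45;  // scaled defaults from debug_options_flags.cc
  double nic_speed_gbps = 25;     // interpreted via the 1024^3 conversion above
  double chunk_prep_us = 6;
  double rtt_us = 31;
  int gpus_per_node = 8;
  int64_t chunk_size_bytes = 4194304;  // 4 MiB
};

constexpr double kHeaderOverhead = 0.025;

double TransferDurationUs(const SolConfig& c, int64_t bytes) {
  return 1e6 * static_cast<double>(bytes) /
         (std::pow(1024.0, 3) * c.nic_speed_gbps) * (1 + kHeaderOverhead);
}

double ChunkPrepLatencyUs(const SolConfig& c, int64_t bytes) {
  return std::ceil(static_cast<double>(bytes) / c.chunk_size_bytes) *
         c.chunk_prep_us;
}

// Ring latency for a non-SendRecv collective (AllGather here: 1 round).
double RingLatencyUs(const SolConfig& c, int64_t buff_bytes, int num_nodes,
                     int num_communicators, int num_rounds) {
  const int num_gpus = num_nodes * c.gpus_per_node / num_communicators;
  const int64_t per_gpu_bytes = buff_bytes / num_gpus;
  const int num_gpus_per_node = num_gpus / num_nodes;
  double us = TransferDurationUs(c, per_gpu_bytes) +
              ChunkPrepLatencyUs(c, per_gpu_bytes) + c.rtt_us;
  us *= (num_gpus - 1.0) / num_gpus_per_node;  // ring steps / NIC parallelism
  us *= num_rounds;  // AllReduce would pass 2 (ReduceScatter + AllGather)
  return us + c.nccl_op_launch_us;  // collective launch overhead
}

int main() {
  SolConfig config;
  const int64_t one_gib = int64_t{1} << 30;
  // 64 GPUs total, 16 MiB per GPU:
  // (640.625 + 24 + 31) * 63 / 8 + 45 ~= 5523 us.
  std::cout << RingLatencyUs(config, one_gib, /*num_nodes=*/8,
                             /*num_communicators=*/1, /*num_rounds=*/1)
            << " us\n";
}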
