Skip to content

Commit 9941e9a

Browse files
[XLA:GPU] Prevent all-reduce codegen when replica groups are empty
Generating collective code when participating devices are not specified is not possible unless topology information is available during compilation. For this reason, this change bails out of codegen when replica_groups is empty. PiperOrigin-RevId: 837190862
1 parent 4bac105 commit 9941e9a

File tree

3 files changed

+83
-43
lines changed

3 files changed

+83
-43
lines changed

xla/backends/gpu/codegen/triton/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,7 @@ cc_library(
10541054
"//xla/tsl/platform:statusor",
10551055
"@com_google_absl//absl/base",
10561056
"@com_google_absl//absl/container:flat_hash_map",
1057+
"@com_google_absl//absl/log",
10571058
"@com_google_absl//absl/status",
10581059
"@com_google_absl//absl/status:statusor",
10591060
"@com_google_absl//absl/strings",
@@ -1063,7 +1064,6 @@ cc_library(
10631064
"@llvm-project//mlir:IR",
10641065
"@llvm-project//mlir:NVVMDialect",
10651066
"@llvm-project//mlir:Support",
1066-
"@llvm-project//mlir:TensorDialect",
10671067
"@triton//:TritonDialects",
10681068
],
10691069
)
@@ -1080,7 +1080,6 @@ xla_cc_test(
10801080
"//xla:status_macros",
10811081
"//xla/backends/gpu/codegen:fusion_emitter",
10821082
"//xla/backends/gpu/codegen:fusions",
1083-
"//xla/hlo/analysis:symbolic_expr",
10841083
"//xla/hlo/ir:hlo",
10851084
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
10861085
"//xla/hlo/utils:hlo_query",
@@ -1095,6 +1094,7 @@ xla_cc_test(
10951094
"@com_google_absl//absl/status",
10961095
"@com_google_absl//absl/status:statusor",
10971096
"@com_google_absl//absl/strings:str_format",
1097+
"@com_google_absl//absl/strings:string_view",
10981098
"@com_google_googletest//:gtest_main",
10991099
"@llvm-project//llvm:ir_headers",
11001100
"@llvm-project//mlir:IR",

xla/backends/gpu/codegen/triton/collective_emitter.cc

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@ limitations under the License.
1919
#include <optional>
2020
#include <type_traits>
2121
#include <utility>
22+
#include <vector>
2223

2324
#include "absl/base/casts.h"
2425
#include "absl/container/flat_hash_map.h"
26+
#include "absl/log/log.h"
2527
#include "absl/status/status.h"
2628
#include "absl/status/statusor.h"
2729
#include "absl/strings/str_cat.h"
2830
#include "llvm/Support/MathExtras.h"
2931
#include "mlir/Dialect/Arith/IR/Arith.h"
3032
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
31-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
3233
#include "mlir/IR/BuiltinTypeInterfaces.h"
3334
#include "mlir/IR/BuiltinTypes.h"
35+
#include "mlir/IR/TypeUtilities.h"
3436
#include "mlir/IR/Types.h"
3537
#include "mlir/IR/Value.h"
3638
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -80,32 +82,68 @@ static constexpr auto kGlobalAddressSpace =
8082
mlir::NVVM::NVVMMemorySpace::Global);
8183

8284
// Metadata arguments for the collective emitter.
83-
// device_rank, signal-value, signal_buffers.
85+
// device_rank, signal_value, signal_buffers.
8486
static constexpr int32_t kNumCollectiveMetadataArgs = 3;
8587

86-
bool CanAllReduceBeEmitted(const HloAllReduceInstruction* all_reduce,
87-
ReductionKind reduction_kind, int64_t num_devices,
88-
int64_t num_elements, PrimitiveType element_type,
89-
AllReduceStrategy all_reduce_strategy) {
88+
struct AllReduceInfo {
89+
ReductionKind reduction_kind;
90+
int64_t num_devices;
91+
int64_t num_elements;
92+
PrimitiveType element_type;
93+
AllReduceStrategy all_reduce_strategy;
94+
};
95+
96+
// Returns the AllReduceInfo for the given all-reduce instruction if the
97+
// instruction is supported by the codegen.
98+
std::optional<AllReduceInfo> MaybeBuildAllReduceInfo(
99+
const HloAllReduceInstruction* all_reduce) {
90100
if (!all_reduce->GetModule()
91101
->config()
92102
.debug_options()
93103
.xla_gpu_unsupported_use_all_reduce_one_shot_kernel()) {
94-
return false;
104+
return std::nullopt;
105+
}
106+
if (all_reduce->device_list().replica_groups().empty()) {
107+
VLOG(1) << "Replica groups are empty for " << all_reduce->name()
108+
<< ". Codegen will not be supported.";
109+
return std::nullopt;
95110
}
111+
const int64_t num_devices = all_reduce->device_list().num_devices_per_group();
112+
const std::optional<ReductionKind> reduction_kind =
113+
MatchReductionComputation(all_reduce->called_computations().front());
114+
if (!reduction_kind.has_value()) {
115+
return std::nullopt;
116+
}
117+
const int64_t num_elements =
118+
ShapeUtil::ElementsIn(all_reduce->operand(0)->shape());
119+
const PrimitiveType element_type =
120+
all_reduce->operand(0)->shape().element_type();
121+
// NB: We do not codegen multimem kernels for now.
122+
const AllReduceStrategy all_reduce_strategy =
123+
GetAllReduceStrategy(num_elements, /*is_multimem_enabled=*/false);
96124
// TODO(b/383125489): Support variadic all-reduce.
97125
if (all_reduce->operand_count() > 1) {
98-
return false;
126+
return std::nullopt;
99127
}
100128
const int64_t byte_size =
101129
num_elements * ShapeUtil::ByteSizeOfPrimitiveType(element_type);
102130
// TODO(b/457333991): Support twoShot for codegen.
103131
if (byte_size >
104132
GetMaxSupportedAllReduceSizeBytes(AllReduceStrategy::kOneShot)) {
105-
return false;
133+
return std::nullopt;
106134
}
107-
return IsAllReduceKernelSupported(num_devices, num_elements, element_type,
108-
reduction_kind, all_reduce_strategy);
135+
if (!IsAllReduceKernelSupported(num_devices, num_elements, element_type,
136+
reduction_kind.value(),
137+
all_reduce_strategy)) {
138+
return std::nullopt;
139+
}
140+
return AllReduceInfo{
141+
/* .reduction_kind= */ reduction_kind.value(),
142+
/* .num_devices= */ num_devices,
143+
/* .num_elements= */ num_elements,
144+
/* .element_type= */ element_type,
145+
/* .all_reduce_strategy= */ all_reduce_strategy,
146+
};
109147
}
110148

111149
// The logic here is very naive and assumes a monotonic layout
@@ -114,27 +152,15 @@ absl::StatusOr<std::optional<BlockLevelFusionConfig>>
114152
GetBlockLevelFusionConfigForAllReduce(
115153
const se::DeviceDescription& device_info,
116154
const HloAllReduceInstruction* all_reduce) {
117-
const std::optional<ReductionKind> reduction_kind =
118-
MatchReductionComputation(all_reduce->called_computations().front());
119-
if (!reduction_kind.has_value()) {
120-
return absl::InternalError(
121-
"Reduction computation not found for all-reduce.");
122-
}
123-
const int64_t num_devices = all_reduce->device_list().num_devices_per_group();
124-
const int64_t num_elements =
125-
ShapeUtil::ElementsIn(all_reduce->operand(0)->shape());
126-
const PrimitiveType element_type =
127-
all_reduce->operand(0)->shape().element_type();
128-
// NB: We do not codegen multimem kernels for now.
129-
const AllReduceStrategy all_reduce_strategy =
130-
GetAllReduceStrategy(num_elements, /*is_multimem_enabled=*/false);
131-
if (!CanAllReduceBeEmitted(all_reduce, reduction_kind.value(), num_devices,
132-
num_elements, element_type, all_reduce_strategy)) {
155+
const std::optional<AllReduceInfo> all_reduce_info =
156+
MaybeBuildAllReduceInfo(all_reduce);
157+
if (!all_reduce_info.has_value()) {
133158
return std::nullopt;
134159
}
135160
const Shape& output_shape = all_reduce->shape();
136-
const LaunchDimensions launch_dims =
137-
AllReduceLaunchDimensions(num_elements, num_devices, all_reduce_strategy);
161+
const LaunchDimensions launch_dims = AllReduceLaunchDimensions(
162+
all_reduce_info->num_elements, all_reduce_info->num_devices,
163+
all_reduce_info->all_reduce_strategy);
138164
BlockLevelFusionConfig block_level_config;
139165
block_level_config.set_num_warps(launch_dims.num_threads_per_block() /
140166
WarpSize(device_info));
@@ -143,8 +169,8 @@ GetBlockLevelFusionConfigForAllReduce(
143169
Tile* output_tile = block_level_config.add_output_tiles();
144170
const int64_t rank = output_shape.dimensions().size();
145171

146-
// Tile sizes are rolled up to power of 2 because this is what the triton
147-
// expects (and consequently the tiling infra).
172+
// Tile sizes are rolled up to power of 2 because this is what triton expects
173+
// and consequently the tiling infra.
148174
for (int i = 0; i < rank - 1; ++i) {
149175
output_tile->add_sizes(llvm::PowerOf2Ceil(output_shape.dimensions(i)));
150176
}

xla/backends/gpu/codegen/triton/collective_emitter_test.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "absl/status/status.h"
2828
#include "absl/status/statusor.h"
2929
#include "absl/strings/str_format.h"
30+
#include "absl/strings/string_view.h"
3031
#include "llvm/IR/Module.h"
3132
#include "mlir/IR/MLIRContext.h"
3233
#include "xla/backends/gpu/codegen/fusion_emitter.h"
@@ -85,8 +86,7 @@ class CollectiveBlockLevelConfigTest : public HloHardwareIndependentTestBase {
8586
: device_info_{TestGpuDeviceInfo::RTXH100SXMDeviceInfo()} {}
8687

8788
absl::StatusOr<ModuleWithFusion> BuildModuleWithFusion(
88-
const Shape& shape) const {
89-
const std::string module_str = GetModuleStr(shape);
89+
std::string module_str) const {
9090
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
9191
ParseAndReturnVerifiedModule(module_str));
9292
const HloInstruction* instr = hlo_query::GetFirstInstructionWithOpcode(
@@ -100,7 +100,8 @@ class CollectiveBlockLevelConfigTest : public HloHardwareIndependentTestBase {
100100
}
101101

102102
protected:
103-
static std::string GetModuleStr(const Shape& shape) {
103+
static std::string GetModuleStr(const Shape& shape,
104+
absl::string_view replica_groups = "{0,1}") {
104105
return absl::StrFormat(R"(
105106
HloModule test
106107
apply_op {
@@ -111,11 +112,11 @@ class CollectiveBlockLevelConfigTest : public HloHardwareIndependentTestBase {
111112
112113
ENTRY test_computation {
113114
param_0 = %1$s parameter(0)
114-
all-reduce-start = %1$s all-reduce-start(param_0), to_apply=apply_op, replica_groups={{0,1}}
115+
all-reduce-start = %1$s all-reduce-start(param_0), to_apply=apply_op, replica_groups={%2$s}
115116
ROOT all-reduce-done = %1$s all-reduce-done(all-reduce-start)
116117
}
117118
)",
118-
shape.ToString());
119+
shape.ToString(), replica_groups);
119120
}
120121

121122
const se::DeviceDescription device_info_;
@@ -124,9 +125,9 @@ class CollectiveBlockLevelConfigTest : public HloHardwareIndependentTestBase {
124125
class CollectiveEmitterTest : public CollectiveBlockLevelConfigTest {
125126
public:
126127
absl::StatusOr<std::unique_ptr<ModuleWithEmitter>> BuildModuleWithEmitter(
127-
const Shape& shape, const se::DeviceDescription& device_info) const {
128+
std::string module_str, const se::DeviceDescription& device_info) const {
128129
TF_ASSIGN_OR_RETURN(ModuleWithFusion module_with_fusion,
129-
BuildModuleWithFusion(shape));
130+
BuildModuleWithFusion(std::move(module_str)));
130131
TF_ASSIGN_OR_RETURN(
131132
bool collective_fusion_config_set,
132133
TrySetGpuBackendConfigForCollective(
@@ -174,7 +175,7 @@ class CollectiveEmitterParameterizedTest
174175
TEST_P(CollectiveEmitterParameterizedTest, AllReduceBlockLevelConfig) {
175176
const auto& param = GetParam();
176177
TF_ASSERT_OK_AND_ASSIGN(const auto module_with_fusion,
177-
BuildModuleWithFusion(param.shape));
178+
BuildModuleWithFusion(GetModuleStr(param.shape)));
178179
TF_ASSERT_OK_AND_ASSIGN(const auto block_level_config,
179180
GetCollectiveBlockLevelFusionConfig(
180181
device_info_, module_with_fusion.FusionInstr()));
@@ -207,10 +208,22 @@ INSTANTIATE_TEST_SUITE_P(
207208
return info.param.test_name;
208209
});
209210

211+
TEST_F(CollectiveEmitterTest, AllReduceBlockLevelConfigNoReplicaGroups) {
212+
TF_ASSERT_OK_AND_ASSIGN(
213+
const auto module_with_fusion,
214+
BuildModuleWithFusion(GetModuleStr(ShapeUtil::MakeShape(F32, {65536}),
215+
/* replica_groups= */ "")));
216+
TF_ASSERT_OK_AND_ASSIGN(const auto block_level_config,
217+
GetCollectiveBlockLevelFusionConfig(
218+
device_info_, module_with_fusion.FusionInstr()));
219+
EXPECT_EQ(block_level_config, std::nullopt);
220+
}
221+
210222
TEST_F(CollectiveEmitterTest, AllReduceWithTritonGetLaunchConfig) {
211223
TF_ASSERT_OK_AND_ASSIGN(
212224
std::unique_ptr<ModuleWithEmitter> result_ptr,
213-
BuildModuleWithEmitter(ShapeUtil::MakeShape(F32, {65536}), device_info_));
225+
BuildModuleWithEmitter(GetModuleStr(ShapeUtil::MakeShape(F32, {65536})),
226+
device_info_));
214227
auto& result = *result_ptr;
215228
const TritonFusion* triton_fusion = result.emitter.get();
216229
ASSERT_NE(triton_fusion, nullptr);
@@ -223,7 +236,8 @@ TEST_F(CollectiveEmitterTest, AllReduceWithTritonGetLaunchConfig) {
223236
TEST_F(CollectiveEmitterTest, AllReduceWithTritonGenerateTritonKernel) {
224237
TF_ASSERT_OK_AND_ASSIGN(
225238
std::unique_ptr<ModuleWithEmitter> result,
226-
BuildModuleWithEmitter(ShapeUtil::MakeShape(F32, {65536}), device_info_));
239+
BuildModuleWithEmitter(GetModuleStr(ShapeUtil::MakeShape(F32, {65536})),
240+
device_info_));
227241
const TritonFusion* triton_fusion = result->emitter.get();
228242
ASSERT_NE(triton_fusion, nullptr);
229243
TF_ASSERT_OK_AND_ASSIGN(

0 commit comments

Comments
 (0)