@@ -19,18 +19,20 @@ limitations under the License.
1919#include < optional>
2020#include < type_traits>
2121#include < utility>
22+ #include < vector>
2223
2324#include " absl/base/casts.h"
2425#include " absl/container/flat_hash_map.h"
26+ #include " absl/log/log.h"
2527#include " absl/status/status.h"
2628#include " absl/status/statusor.h"
2729#include " absl/strings/str_cat.h"
2830#include " llvm/Support/MathExtras.h"
2931#include " mlir/Dialect/Arith/IR/Arith.h"
3032#include " mlir/Dialect/LLVMIR/NVVMDialect.h"
31- #include " mlir/Dialect/Tensor/IR/Tensor.h"
3233#include " mlir/IR/BuiltinTypeInterfaces.h"
3334#include " mlir/IR/BuiltinTypes.h"
35+ #include " mlir/IR/TypeUtilities.h"
3436#include " mlir/IR/Types.h"
3537#include " mlir/IR/Value.h"
3638#include " mlir/Interfaces/FunctionInterfaces.h"
@@ -80,32 +82,68 @@ static constexpr auto kGlobalAddressSpace =
8082 mlir::NVVM::NVVMMemorySpace::Global);
8183
8284// Metadata arguments for the collective emitter.
83- // device_rank, signal-value , signal_buffers.
85+ // device_rank, signal_value , signal_buffers.
8486static constexpr int32_t kNumCollectiveMetadataArgs = 3 ;
8587
86- bool CanAllReduceBeEmitted (const HloAllReduceInstruction* all_reduce,
87- ReductionKind reduction_kind, int64_t num_devices,
88- int64_t num_elements, PrimitiveType element_type,
89- AllReduceStrategy all_reduce_strategy) {
88+ struct AllReduceInfo {
89+ ReductionKind reduction_kind;
90+ int64_t num_devices;
91+ int64_t num_elements;
92+ PrimitiveType element_type;
93+ AllReduceStrategy all_reduce_strategy;
94+ };
95+
96+ // Returns the AllReduceInfo for the given all-reduce instruction if the
97+ // instruction is supported by the codegen, and std::nullopt otherwise.
98+ std::optional<AllReduceInfo> MaybeBuildAllReduceInfo (
99+ const HloAllReduceInstruction* all_reduce) {
90100 if (!all_reduce->GetModule ()
91101 ->config ()
92102 .debug_options ()
93103 .xla_gpu_unsupported_use_all_reduce_one_shot_kernel ()) {
94- return false ;
104+ return std::nullopt ;
105+ }
106+ if (all_reduce->device_list ().replica_groups ().empty ()) {
107+ VLOG (1 ) << " Replica groups are empty for " << all_reduce->name ()
108+ << " . Codegen will not be supported." ;
109+ return std::nullopt ;
95110 }
111+ const int64_t num_devices = all_reduce->device_list ().num_devices_per_group ();
112+ const std::optional<ReductionKind> reduction_kind =
113+ MatchReductionComputation (all_reduce->called_computations ().front ());
114+ if (!reduction_kind.has_value ()) {
115+ return std::nullopt ;
116+ }
117+ const int64_t num_elements =
118+ ShapeUtil::ElementsIn (all_reduce->operand (0 )->shape ());
119+ const PrimitiveType element_type =
120+ all_reduce->operand (0 )->shape ().element_type ();
121+ // NB: We do not codegen multimem kernels for now.
122+ const AllReduceStrategy all_reduce_strategy =
123+ GetAllReduceStrategy (num_elements, /* is_multimem_enabled=*/ false );
96124 // TODO(b/383125489): Support variadic all-reduce.
97125 if (all_reduce->operand_count () > 1 ) {
98- return false ;
126+ return std:: nullopt ;
99127 }
100128 const int64_t byte_size =
101129 num_elements * ShapeUtil::ByteSizeOfPrimitiveType (element_type);
102130 // TODO(b/457333991): Support twoShot for codegen.
103131 if (byte_size >
104132 GetMaxSupportedAllReduceSizeBytes (AllReduceStrategy::kOneShot )) {
105- return false ;
133+ return std:: nullopt ;
106134 }
107- return IsAllReduceKernelSupported (num_devices, num_elements, element_type,
108- reduction_kind, all_reduce_strategy);
135+ if (!IsAllReduceKernelSupported (num_devices, num_elements, element_type,
136+ reduction_kind.value (),
137+ all_reduce_strategy)) {
138+ return std::nullopt ;
139+ }
140+ return AllReduceInfo{
141+ /* .reduction_kind= */ reduction_kind.value (),
142+ /* .num_devices= */ num_devices,
143+ /* .num_elements= */ num_elements,
144+ /* .element_type= */ element_type,
145+ /* .all_reduce_strategy= */ all_reduce_strategy,
146+ };
109147}
110148
111149// The logic here is very naive and assumes a monotonic layout
@@ -114,27 +152,15 @@ absl::StatusOr<std::optional<BlockLevelFusionConfig>>
114152GetBlockLevelFusionConfigForAllReduce (
115153 const se::DeviceDescription& device_info,
116154 const HloAllReduceInstruction* all_reduce) {
117- const std::optional<ReductionKind> reduction_kind =
118- MatchReductionComputation (all_reduce->called_computations ().front ());
119- if (!reduction_kind.has_value ()) {
120- return absl::InternalError (
121- " Reduction computation not found for all-reduce." );
122- }
123- const int64_t num_devices = all_reduce->device_list ().num_devices_per_group ();
124- const int64_t num_elements =
125- ShapeUtil::ElementsIn (all_reduce->operand (0 )->shape ());
126- const PrimitiveType element_type =
127- all_reduce->operand (0 )->shape ().element_type ();
128- // NB: We do not codegen multimem kernels for now.
129- const AllReduceStrategy all_reduce_strategy =
130- GetAllReduceStrategy (num_elements, /* is_multimem_enabled=*/ false );
131- if (!CanAllReduceBeEmitted (all_reduce, reduction_kind.value (), num_devices,
132- num_elements, element_type, all_reduce_strategy)) {
155+ const std::optional<AllReduceInfo> all_reduce_info =
156+ MaybeBuildAllReduceInfo (all_reduce);
157+ if (!all_reduce_info.has_value ()) {
133158 return std::nullopt ;
134159 }
135160 const Shape& output_shape = all_reduce->shape ();
136- const LaunchDimensions launch_dims =
137- AllReduceLaunchDimensions (num_elements, num_devices, all_reduce_strategy);
161+ const LaunchDimensions launch_dims = AllReduceLaunchDimensions (
162+ all_reduce_info->num_elements , all_reduce_info->num_devices ,
163+ all_reduce_info->all_reduce_strategy );
138164 BlockLevelFusionConfig block_level_config;
139165 block_level_config.set_num_warps (launch_dims.num_threads_per_block () /
140166 WarpSize (device_info));
@@ -143,8 +169,8 @@ GetBlockLevelFusionConfigForAllReduce(
143169 Tile* output_tile = block_level_config.add_output_tiles ();
144170 const int64_t rank = output_shape.dimensions ().size ();
145171
146- // Tile sizes are rolled up to power of 2 because this is what the triton
147- // expects ( and consequently the tiling infra) .
172+ // Tile sizes are rounded up to a power of 2 because this is what Triton
173+ // expects (and, consequently, the tiling infra).
148174 for (int i = 0 ; i < rank - 1 ; ++i) {
149175 output_tile->add_sizes (llvm::PowerOf2Ceil (output_shape.dimensions (i)));
150176 }
0 commit comments