@@ -81,6 +81,7 @@ using emitters::PartitionedComputations;
8181using emitters::ProvideParameter;
8282using llvm::APFloat;
8383using llvm::APInt;
84+ using llvm::ArrayRef;
8485using llvm::SmallVector;
8586using mlir::AffineExpr;
8687using mlir::AffineMap;
@@ -171,7 +172,7 @@ Value UpdateIsInbounds(ImplicitLocOpBuilder& b, Value is_inbounds,
171172 .front ();
172173}
173174
174- SmallVector<Value> Pack (llvm:: ArrayRef<ValueRange> ranges) {
175+ SmallVector<Value> Pack (ArrayRef<ValueRange> ranges) {
175176 int64_t total_size = 0 ;
176177 for (auto & range : ranges) {
177178 total_size += range.size ();
@@ -184,8 +185,7 @@ SmallVector<Value> Pack(llvm::ArrayRef<ValueRange> ranges) {
184185 return result;
185186}
186187
187- SmallVector<ValueRange> Unpack (ValueRange range,
188- llvm::ArrayRef<int64_t > sizes) {
188+ SmallVector<ValueRange> Unpack (ValueRange range, ArrayRef<int64_t > sizes) {
189189 int64_t total_size = 0 ;
190190 for (auto & size : sizes) {
191191 total_size += size;
@@ -213,29 +213,42 @@ SmallVector<Value, 4> PadWithZeros(ValueRange values, int64_t size,
213213}
214214
215215// Creates a new indexing map that is the same as `map` but with the range
216- // variable at `range_var_index` replaced with the new dimension variable at
217- // `dimension_{dim_var_size)`. Potentially, it can be moved to indexing_map.h.
218- IndexingMap ConvertRangeVariableToDimension (const IndexingMap& map,
219- int64_t range_var_index) {
216+ // variables at `range_var_indices` converted to new dimension variables that
217+ // are appended to the end of the dimension variables list. Potentially, it can
218+ // be moved to indexing_map.h.
219+ IndexingMap ConvertRangeVariableToDimension (
220+ const IndexingMap& map, ArrayRef<int64_t > range_var_indices) {
221+ CHECK (std::is_sorted (range_var_indices.begin (), range_var_indices.end ()));
220222 auto * mlir_context = map.GetMLIRContext ();
221223
222224 AffineMap affine_map = map.GetAffineMap ();
223- // Update the affine map.
225+ // Update the affine map and the variables.
226+ std::vector<IndexingMap::Variable> dims = map.GetDimVars ();
227+ std::vector<IndexingMap::Variable> range_vars;
228+ std::vector<IndexingMap::Variable> rt_vars = map.GetRTVars ();
224229 SmallVector<AffineExpr, 4 > symbol_replacements;
225230 symbol_replacements.reserve (affine_map.getNumSymbols ());
231+ int64_t range_var_count = 0 ;
232+ int64_t range_var_indices_count = range_var_indices.size ();
226233 for (int i = 0 ; i < affine_map.getNumSymbols (); ++i) {
227- if (i == range_var_index) {
234+ auto range_var = map.GetRangeVar (i);
235+ if (range_var_count < range_var_indices_count &&
236+ i == range_var_indices[range_var_count]) {
228237 symbol_replacements.push_back (
229238 getAffineDimExpr (affine_map.getNumDims (), mlir_context));
239+ dims.push_back (range_var);
240+ range_var_count++;
230241 } else {
231242 symbol_replacements.push_back (
232- getAffineSymbolExpr (i > range_var_index ? i - 1 : i, mlir_context));
243+ getAffineSymbolExpr (i - range_var_count, mlir_context));
244+ range_vars.push_back (range_var);
233245 }
234246 }
235247
236248 AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols (
237- {}, symbol_replacements, affine_map.getNumDims () + 1 ,
238- affine_map.getNumSymbols () - 1 );
249+ {}, symbol_replacements,
250+ affine_map.getNumDims () + range_var_indices_count,
251+ affine_map.getNumSymbols () - range_var_indices_count);
239252
240253 // Update the constraints.
241254 std::vector<std::pair<AffineExpr, Interval>> constraints;
@@ -244,13 +257,6 @@ IndexingMap ConvertRangeVariableToDimension(const IndexingMap& map,
244257 constraints.push_back ({constraint.first .replaceSymbols (symbol_replacements),
245258 constraint.second });
246259 }
247- // Update the variables.
248- std::vector<IndexingMap::Variable> dims = map.GetDimVars ();
249- std::vector<IndexingMap::Variable> range_vars = map.GetRangeVars ();
250- std::vector<IndexingMap::Variable> rt_vars = map.GetRTVars ();
251-
252- dims.push_back (range_vars[range_var_index]);
253- range_vars.erase (range_vars.begin () + range_var_index);
254260 return IndexingMap{converted_affine_map, std::move (dims),
255261 std::move (range_vars), std::move (rt_vars), constraints};
256262}
@@ -299,7 +305,8 @@ class EmitterHelper {
299305
300306 Value WriteAccumulatorToOutput (ImplicitLocOpBuilder& b,
301307 Value write_to_output_required,
302- ValueRange thread_and_block_ids, Value iv,
308+ ValueRange thread_and_block_ids,
309+ Value index_id,
303310 const IndexingMap& slice_indexing,
304311 ValueRange offsets, Value accumulator,
305312 Value output_tensor) const ;
@@ -370,10 +377,10 @@ SmallVector<Value> EmitterHelper::WriteAccumulatedElementToOutput(
370377
371378Value EmitterHelper::WriteAccumulatorToOutput (
372379 ImplicitLocOpBuilder& b, Value write_to_output_required,
373- ValueRange thread_and_block_ids, Value iv ,
380+ ValueRange thread_and_block_ids, Value index_id ,
374381 const IndexingMap& slice_indexing, ValueRange offsets, Value accumulator,
375382 Value output_tensor) const {
376- SmallVector<Value> dims = Pack ({thread_and_block_ids, iv });
383+ SmallVector<Value> dims = Pack ({thread_and_block_ids, index_id });
377384 return EmitUpdateIf (
378385 b, write_to_output_required, output_tensor,
379386 [&](ImplicitLocOpBuilder& if_builder) -> SmallVector<Value> {
@@ -569,9 +576,9 @@ absl::Status ScatterWithDistributedUpdates::EmitEntryFunctionImpl(
569576 ValueRange thread_and_block_ids, Value output_tensor) const {
570577 if (VLOG_IS_ON (5 )) {
571578 llvm::errs () << " Settings for ScatterWithDistributedUpdates: \n "
572- << " vector_size_ : " << vector_size_ << " \n "
573- << " num_warps_ : " << num_warps_ << " \n "
574- << " num_blocks_ : " << num_blocks_;
579+ << " vector_size : " << vector_size_ << " \n "
580+ << " num_warps : " << num_warps_ << " \n "
581+ << " num_blocks : " << num_blocks_ << " \n " ;
575582 }
576583 EmitNaiveImplementation (b, description_, helper, updates_map, indices_map,
577584 thread_and_block_ids, output_tensor);
@@ -581,10 +588,11 @@ absl::Status ScatterWithDistributedUpdates::EmitEntryFunctionImpl(
581588ScatterWithDistributedIndices::ScatterWithDistributedIndices (
582589 const HloFusionAnalysis& analysis, const ScatterDescription& description,
583590 int64_t vector_size, int64_t num_warps_per_slice,
584- int64_t num_indices_per_warp)
591+ int64_t num_indices_per_warp, int64_t indices_vector_size )
585592 : ScatterFusion(analysis, description, vector_size),
586593 num_warps_per_slice_(num_warps_per_slice),
587- num_indices_per_warp_(num_indices_per_warp) {
594+ num_indices_per_warp_(num_indices_per_warp),
595+ indices_vector_size_(indices_vector_size) {
588596 num_warps_ = kNumWarpsPerBlock ;
589597 num_blocks_ = CeilOfRatio (description.num_slices * num_warps_per_slice_,
590598 num_indices_per_warp_ * num_warps_);
@@ -605,32 +613,40 @@ void ScatterWithDistributedIndices::ComputeIndexing(
605613 (block_x * num_warps_ + warp_id) % num_warps_per_slice_;
606614 auto lane_id = thread_x % warp_size_;
607615 auto index_id_loop = getAffineSymbolExpr (0 , ctx);
616+ auto index_vector_id = getAffineSymbolExpr (1 , ctx);
608617
609- auto index_id_expr = slice_id * num_indices_per_warp_ + index_id_loop;
610- std::pair<AffineExpr, Interval> index_id_constraint =
611- std::make_pair (index_id_expr, Interval{ 0 , description_. num_slices - 1 }) ;
618+ auto vectorized_index_id_expr = slice_id * num_indices_per_warp_ +
619+ index_id_loop * indices_vector_size_ +
620+ index_vector_id ;
612621
613622 auto grid_vars =
614623 DimVarsFromGPUGrid ({num_warps_ * warp_size_, 1 , 1 , num_blocks_, 1 , 1 });
615624 if (indices_map) {
616- auto index_dim_loop = getAffineSymbolExpr (1 , ctx);
625+ auto index_dim_loop = getAffineSymbolExpr (2 , ctx);
617626 *indices_map = IndexingMap{
618- AffineMap::get (6 , 2 , {index_id_expr , index_dim_loop}, ctx),
627+ AffineMap::get (6 , 3 , {vectorized_index_id_expr , index_dim_loop}, ctx),
619628 grid_vars,
620- {IndexingMap::Variable{{0 , num_indices_per_warp_ - 1 }, " index_id_loop" },
629+ {IndexingMap::Variable{
630+ {0 , num_indices_per_warp_ / indices_vector_size_ - 1 },
631+ " index_id_loop" },
632+ IndexingMap::Variable{{0 , indices_vector_size_ - 1 },
633+ " index_vector_id" },
621634 IndexingMap::Variable{{0 , description_.index_vector_length - 1 },
622635 " index_dim" }},
623636 /* rt_vars=*/ {},
624- {index_id_constraint}};
637+ {std::make_pair (vectorized_index_id_expr,
638+ Interval{0 , description_.num_slices - 1 })}};
625639
626640 indices_map->Simplify ();
627641 }
628642
629643 if (updates_map) {
644+ auto index_id = getAffineSymbolExpr (0 , ctx);
630645 auto update_dim_loop = getAffineSymbolExpr (1 , ctx);
631646 auto vector_id = getAffineSymbolExpr (2 , ctx);
632647 auto num_elements_per_slice = Product (description_.slice_shape );
633648
649+ auto index_id_expr = slice_id * num_indices_per_warp_ + index_id;
634650 auto linear_slice_index =
635651 warp_id_in_slice * warp_size_ * vector_size_ +
636652 update_dim_loop * vector_size_ * warp_size_ * num_warps_per_slice_ +
@@ -652,7 +668,8 @@ void ScatterWithDistributedIndices::ComputeIndexing(
652668 IndexingMap::Variable{{0 , vector_size_ - 1 }, " vector_id" }},
653669 /* rt_vars=*/ {},
654670 std::vector<std::pair<AffineExpr, Interval>>{
655- index_id_constraint,
671+ std::make_pair (index_id_expr,
672+ Interval{0 , description_.num_slices - 1 }),
656673 std::make_pair (linear_slice_index,
657674 Interval{0 , num_elements_per_slice - 1 })}};
658675
@@ -695,11 +712,12 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
695712 ValueRange thread_and_block_ids, Value output_tensor) const {
696713 if (VLOG_IS_ON (5 )) {
697714 llvm::errs () << " Settings for ScatterWithDistributedIndices: \n "
698- << " vector_size_: " << vector_size_ << " \n "
699- << " num_warps_: " << num_warps_ << " \n "
700- << " num_blocks_: " << num_blocks_
701- << " num_warps_per_slice_: " << num_warps_per_slice_ << " \n "
702- << " num_indices_per_warp_: " << num_indices_per_warp_;
715+ << " vector_size: " << vector_size_ << " \n "
716+ << " num_warps: " << num_warps_ << " \n "
717+ << " num_blocks: " << num_blocks_ << " \n "
718+ << " num_warps_per_slice: " << num_warps_per_slice_ << " \n "
719+ << " num_indices_per_warp: " << num_indices_per_warp_ << " \n "
720+ << " indices_vector_size: " << indices_vector_size_ << " \n " ;
703721 }
704722 if (num_indices_per_warp_ == 1 ) {
705723 EmitNaiveImplementation (b, description_, helper, updates_map, indices_map,
@@ -709,12 +727,17 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
709727 MLIRContext* mlir_context = b.getContext ();
710728
711729 auto thread_id_to_update_id_map = IndexingMap (
712- AffineMap::get (6 , 1 , {updates_map .GetAffineMap ().getResult (0 )},
730+ AffineMap::get (6 , 2 , {indices_map .GetAffineMap ().getResult (0 )},
713731 mlir_context),
714- updates_map.GetDimVars (),
715- /* range_vars = */ {updates_map.GetRangeVars ().front ()},
716- /* rt vars = */ {});
717- IndexingMap slice_indexing = ConvertRangeVariableToDimension (updates_map, 0 );
732+ indices_map.GetDimVars (),
733+ /* range_vars = */
734+ {indices_map.GetRangeVars ().begin (),
735+ indices_map.GetRangeVars ().begin () + 2 },
736+ /* rt vars = */ {}, indices_map.GetConstraints ());
737+
738+ // Convert the first range variable of `updates_map` to a dimension variable.
739+ IndexingMap slice_indexing =
740+ ConvertRangeVariableToDimension (updates_map, {0 });
718741
719742 // Prepare loop initial values. Inits are packed as
720743 // [index_changed, is_inbounds, index_0, ..., accumulator].
@@ -740,7 +763,14 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
740763 Value iter_is_inbounds = iter_args_unpack[2 ].front ();
741764 Value iter_acc = iter_args_unpack[3 ].front ();
742765 Value iter_output = iter_args_unpack[4 ].front ();
743- Value iter_slice_id = ivs.front ();
766+ CHECK_EQ (ivs.size (), 2 );
767+ Value index_loop_id = ivs.front ();
768+ Value index_vector_id = ivs.back ();
769+ Value iter_slice_id = nested_b.create <arith::AddIOp>(
770+ nested_b.create <arith::MulIOp>(
771+ index_loop_id,
772+ nested_b.create <arith::ConstantIndexOp>(indices_vector_size_)),
773+ index_vector_id);
744774
745775 SmallVector<Value> offsets =
746776 PadWithZeros (trimmed_offsets, output_rank, nested_b);
@@ -926,32 +956,44 @@ std::unique_ptr<ScatterFusion> CreateScatterFusion(
926956
927957 int64_t max_active_warps =
928958 kNumWarpsPerBlock * analysis.device_info ().core_count ();
929- // For sorted scatter, we try to estimate the number of updates per warp by
930- // computing the ratio of the number of the given updates to the number of the
931- // possible valid indices. If we do not have multiple updates per warp, there
932- // is no reason to use this algorithm.
933959 // TODO(b/385081952): Investigate why bf16 and f64 lead to incorrect results.
934- if (description.scatter ->indices_are_sorted () &&
935- description.elem_type != BF16 && num_slices > 2 * max_active_warps) {
936- int64_t num_indices_per_warp = CeilOfRatio (
937- num_slices, GetNumPossibleValidIndices (
938- description.slice_shape , description.output_shape ,
939- description.index_vector_length ));
940- int64_t num_warps_per_slice = 1 ;
941- if (num_indices_per_warp > 2 &&
942- num_active_threads_per_warp > warp_size / 2 ) {
943- return std::make_unique<ScatterWithDistributedIndices>(
944- analysis, description, vector_size, num_warps_per_slice,
945- num_indices_per_warp);
946- }
947- }
948960 // If we have enough data, we assign each warp to process a single
949961 // slice.
950962 if (num_slices > max_active_warps &&
951963 num_active_threads_per_warp > warp_size / 2 ) {
964+ int64_t num_indices_per_warp = 1 ;
965+ int64_t indices_vector_size = 1 ;
966+ int64_t num_warps_per_slice = 1 ;
967+ // For sorted scatter, we try to estimate the number of updates per warp by
968+ // computing the ratio of the number of the given updates to the number of
969+ // the possible valid indices. If we do not have multiple updates per warp,
970+ // there is no reason to use this algorithm.
971+ if (description.scatter ->indices_are_sorted ()) {
972+ num_indices_per_warp = CeilOfRatio (
973+ num_slices,
974+ std::max (max_active_warps,
975+ GetNumPossibleValidIndices (
976+ description.slice_shape , description.output_shape ,
977+ description.index_vector_length )));
978+
979+ // If the index_vector_length is 1, we can vectorize the indices read.
980+ int64_t index_elem_type_bits = primitive_util::BitWidth (
981+ description.scatter ->scatter_indices ()->shape ().element_type ());
982+ int64_t max_vectorized_indices =
983+ kMaxVectorizedBits / index_elem_type_bits;
984+ if (description.index_vector_length == 1 &&
985+ num_indices_per_warp > max_vectorized_indices) {
986+ // Pad num_indices_per_warp to the next multiple of
987+ // max_vectorized_indices.
988+ num_indices_per_warp =
989+ CeilOfRatio (num_indices_per_warp, max_vectorized_indices) *
990+ max_vectorized_indices;
991+ indices_vector_size = max_vectorized_indices;
992+ }
993+ }
952994 return std::make_unique<ScatterWithDistributedIndices>(
953- analysis, description, vector_size,
954- /* num_warps_per_slice= */ 1 , /* num_indices_per_warp= */ 1 );
995+ analysis, description, vector_size, num_warps_per_slice,
996+ num_indices_per_warp, indices_vector_size );
955997 }
956998 // Otherwise, we distribute the linearized updates tensor.
957999 vector_size =
0 commit comments