
Commit a32e667

pifon2a authored and Google-ML-Automation committed
[XLA:GPU] Enable vectorization of the indices operand for scatter.
PiperOrigin-RevId: 715329926
1 parent 8a11e8e commit a32e667
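Note: in plain terms, the ScatterWithDistributedIndices emitter previously read one scatter index per loop iteration; after this change, each warp loop step covers `indices_vector_size` consecutive indices so the loads can be emitted as one vector read. A minimal standalone C++ sketch of the loop restructuring (illustrative names, not the MLIR emitter itself):

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// Sketch: visit the indices a warp owns. Before this commit the loop advanced
// one index at a time; now an outer loop steps by the vector width and an
// inner loop walks the lanes, so the loads coalesce into one vector read.
void VisitWarpIndices(const std::vector<int32_t>& indices, int64_t index_start,
                      int64_t num_indices_per_warp,
                      int64_t indices_vector_size,
                      const std::function<void(int64_t, int32_t)>& visit) {
  for (int64_t i = 0; i < num_indices_per_warp / indices_vector_size; ++i) {
    for (int64_t j = 0; j < indices_vector_size; ++j) {
      const int64_t id = i * indices_vector_size + j;  // de-vectorized id
      visit(id, indices[index_start + id]);
    }
  }
}
```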

3 files changed: +128 −78 lines

xla/backends/gpu/codegen/emitters/scatter.cc

Lines changed: 108 additions & 66 deletions
```diff
@@ -81,6 +81,7 @@ using emitters::PartitionedComputations;
 using emitters::ProvideParameter;
 using llvm::APFloat;
 using llvm::APInt;
+using llvm::ArrayRef;
 using llvm::SmallVector;
 using mlir::AffineExpr;
 using mlir::AffineMap;
@@ -171,7 +172,7 @@ Value UpdateIsInbounds(ImplicitLocOpBuilder& b, Value is_inbounds,
       .front();
 }
 
-SmallVector<Value> Pack(llvm::ArrayRef<ValueRange> ranges) {
+SmallVector<Value> Pack(ArrayRef<ValueRange> ranges) {
   int64_t total_size = 0;
   for (auto& range : ranges) {
     total_size += range.size();
@@ -184,8 +185,7 @@ SmallVector<Value> Pack(llvm::ArrayRef<ValueRange> ranges) {
   return result;
 }
 
-SmallVector<ValueRange> Unpack(ValueRange range,
-                               llvm::ArrayRef<int64_t> sizes) {
+SmallVector<ValueRange> Unpack(ValueRange range, ArrayRef<int64_t> sizes) {
   int64_t total_size = 0;
   for (auto& size : sizes) {
     total_size += size;
```
```diff
@@ -213,29 +213,42 @@ SmallVector<Value, 4> PadWithZeros(ValueRange values, int64_t size,
 }
 
 // Creates a new indexing map that is the same as `map` but with the range
-// variable at `range_var_index` replaced with the new dimension variable at
-// `dimension_{dim_var_size)`. Potentially, it can be moved to indexing_map.h.
-IndexingMap ConvertRangeVariableToDimension(const IndexingMap& map,
-                                            int64_t range_var_index) {
+// variables at `range_var_indices` converted to new dimension variables and
+// appended to the end of the dimension variable list. Potentially, it can be
+// moved to indexing_map.h.
+IndexingMap ConvertRangeVariableToDimension(
+    const IndexingMap& map, ArrayRef<int64_t> range_var_indices) {
+  CHECK(std::is_sorted(range_var_indices.begin(), range_var_indices.end()));
   auto* mlir_context = map.GetMLIRContext();
 
   AffineMap affine_map = map.GetAffineMap();
-  // Update the affine map.
+  // Update the affine map and the variables.
+  std::vector<IndexingMap::Variable> dims = map.GetDimVars();
+  std::vector<IndexingMap::Variable> range_vars;
+  std::vector<IndexingMap::Variable> rt_vars = map.GetRTVars();
   SmallVector<AffineExpr, 4> symbol_replacements;
   symbol_replacements.reserve(affine_map.getNumSymbols());
+  int64_t range_var_count = 0;
+  int64_t range_var_indices_count = range_var_indices.size();
   for (int i = 0; i < affine_map.getNumSymbols(); ++i) {
-    if (i == range_var_index) {
+    auto range_var = map.GetRangeVar(i);
+    if (range_var_count < range_var_indices_count &&
+        i == range_var_indices[range_var_count]) {
       symbol_replacements.push_back(
           getAffineDimExpr(affine_map.getNumDims(), mlir_context));
+      dims.push_back(range_var);
+      range_var_count++;
     } else {
       symbol_replacements.push_back(
-          getAffineSymbolExpr(i > range_var_index ? i - 1 : i, mlir_context));
+          getAffineSymbolExpr(i - range_var_count, mlir_context));
+      range_vars.push_back(range_var);
     }
   }
 
   AffineMap converted_affine_map = affine_map.replaceDimsAndSymbols(
-      {}, symbol_replacements, affine_map.getNumDims() + 1,
-      affine_map.getNumSymbols() - 1);
+      {}, symbol_replacements,
+      affine_map.getNumDims() + range_var_indices_count,
+      affine_map.getNumSymbols() - range_var_indices_count);
 
   // Update the constraints.
   std::vector<std::pair<AffineExpr, Interval>> constraints;
@@ -244,13 +257,6 @@ IndexingMap ConvertRangeVariableToDimension(const IndexingMap& map,
     constraints.push_back({constraint.first.replaceSymbols(symbol_replacements),
                            constraint.second});
   }
-  // Update the variables.
-  std::vector<IndexingMap::Variable> dims = map.GetDimVars();
-  std::vector<IndexingMap::Variable> range_vars = map.GetRangeVars();
-  std::vector<IndexingMap::Variable> rt_vars = map.GetRTVars();
-
-  dims.push_back(range_vars[range_var_index]);
-  range_vars.erase(range_vars.begin() + range_var_index);
   return IndexingMap{converted_affine_map, std::move(dims),
                      std::move(range_vars), std::move(rt_vars), constraints};
 }
```
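For intuition, a dependency-free sketch of the renumbering above: each range variable named in `range_var_indices` becomes a trailing dimension, and the remaining symbols are shifted down to close the gaps. In this sketch each promoted variable gets its own new dimension; the commit itself only ever passes a single index (`{0}`), for which the two behaviors coincide. Names are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

// Toy model of ConvertRangeVariableToDimension's renumbering: symbol s_i is
// either promoted to a new trailing dimension or shifted down by the number
// of symbols already promoted before it.
std::vector<std::string> RenumberSymbols(
    int num_dims, int num_symbols, const std::vector<int>& range_var_indices) {
  assert(std::is_sorted(range_var_indices.begin(), range_var_indices.end()));
  std::vector<std::string> replacements;
  int promoted = 0;
  for (int i = 0; i < num_symbols; ++i) {
    if (promoted < static_cast<int>(range_var_indices.size()) &&
        i == range_var_indices[promoted]) {
      replacements.push_back("d" + std::to_string(num_dims + promoted));
      ++promoted;
    } else {
      replacements.push_back("s" + std::to_string(i - promoted));
    }
  }
  return replacements;
}

int main() {
  // 2 dims, 4 symbols, promote s1 and s3: s0 -> s0, s1 -> d2, s2 -> s1, s3 -> d3.
  for (const std::string& r : RenumberSymbols(2, 4, {1, 3})) {
    std::printf("%s ", r.c_str());
  }
  std::printf("\n");
}
```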
```diff
@@ -299,7 +305,8 @@ class EmitterHelper {
 
   Value WriteAccumulatorToOutput(ImplicitLocOpBuilder& b,
                                  Value write_to_output_required,
-                                 ValueRange thread_and_block_ids, Value iv,
+                                 ValueRange thread_and_block_ids,
+                                 Value index_id,
                                  const IndexingMap& slice_indexing,
                                  ValueRange offsets, Value accumulator,
                                  Value output_tensor) const;
@@ -370,10 +377,10 @@ SmallVector<Value> EmitterHelper::WriteAccumulatedElementToOutput(
 
 Value EmitterHelper::WriteAccumulatorToOutput(
     ImplicitLocOpBuilder& b, Value write_to_output_required,
-    ValueRange thread_and_block_ids, Value iv,
+    ValueRange thread_and_block_ids, Value index_id,
     const IndexingMap& slice_indexing, ValueRange offsets, Value accumulator,
     Value output_tensor) const {
-  SmallVector<Value> dims = Pack({thread_and_block_ids, iv});
+  SmallVector<Value> dims = Pack({thread_and_block_ids, index_id});
   return EmitUpdateIf(
       b, write_to_output_required, output_tensor,
       [&](ImplicitLocOpBuilder& if_builder) -> SmallVector<Value> {
@@ -569,9 +576,9 @@ absl::Status ScatterWithDistributedUpdates::EmitEntryFunctionImpl(
     ValueRange thread_and_block_ids, Value output_tensor) const {
   if (VLOG_IS_ON(5)) {
     llvm::errs() << "Settings for ScatterWithDistributedUpdates: \n"
-                 << "vector_size_: " << vector_size_ << "\n"
-                 << "num_warps_: " << num_warps_ << "\n"
-                 << "num_blocks_: " << num_blocks_;
+                 << "vector_size: " << vector_size_ << "\n"
+                 << "num_warps: " << num_warps_ << "\n"
+                 << "num_blocks: " << num_blocks_ << "\n";
   }
   EmitNaiveImplementation(b, description_, helper, updates_map, indices_map,
                           thread_and_block_ids, output_tensor);
@@ -581,10 +588,11 @@ absl::Status ScatterWithDistributedUpdates::EmitEntryFunctionImpl(
 ScatterWithDistributedIndices::ScatterWithDistributedIndices(
     const HloFusionAnalysis& analysis, const ScatterDescription& description,
     int64_t vector_size, int64_t num_warps_per_slice,
-    int64_t num_indices_per_warp)
+    int64_t num_indices_per_warp, int64_t indices_vector_size)
     : ScatterFusion(analysis, description, vector_size),
       num_warps_per_slice_(num_warps_per_slice),
-      num_indices_per_warp_(num_indices_per_warp) {
+      num_indices_per_warp_(num_indices_per_warp),
+      indices_vector_size_(indices_vector_size) {
   num_warps_ = kNumWarpsPerBlock;
   num_blocks_ = CeilOfRatio(description.num_slices * num_warps_per_slice_,
                             num_indices_per_warp_ * num_warps_);
@@ -605,32 +613,40 @@ void ScatterWithDistributedIndices::ComputeIndexing(
       (block_x * num_warps_ + warp_id) % num_warps_per_slice_;
   auto lane_id = thread_x % warp_size_;
   auto index_id_loop = getAffineSymbolExpr(0, ctx);
+  auto index_vector_id = getAffineSymbolExpr(1, ctx);
 
-  auto index_id_expr = slice_id * num_indices_per_warp_ + index_id_loop;
-  std::pair<AffineExpr, Interval> index_id_constraint =
-      std::make_pair(index_id_expr, Interval{0, description_.num_slices - 1});
+  auto vectorized_index_id_expr = slice_id * num_indices_per_warp_ +
+                                  index_id_loop * indices_vector_size_ +
+                                  index_vector_id;
 
   auto grid_vars =
       DimVarsFromGPUGrid({num_warps_ * warp_size_, 1, 1, num_blocks_, 1, 1});
   if (indices_map) {
-    auto index_dim_loop = getAffineSymbolExpr(1, ctx);
+    auto index_dim_loop = getAffineSymbolExpr(2, ctx);
     *indices_map = IndexingMap{
-        AffineMap::get(6, 2, {index_id_expr, index_dim_loop}, ctx),
+        AffineMap::get(6, 3, {vectorized_index_id_expr, index_dim_loop}, ctx),
         grid_vars,
-        {IndexingMap::Variable{{0, num_indices_per_warp_ - 1}, "index_id_loop"},
+        {IndexingMap::Variable{
+             {0, num_indices_per_warp_ / indices_vector_size_ - 1},
+             "index_id_loop"},
+         IndexingMap::Variable{{0, indices_vector_size_ - 1},
+                               "index_vector_id"},
         IndexingMap::Variable{{0, description_.index_vector_length - 1},
                                "index_dim"}},
         /*rt_vars=*/{},
-        {index_id_constraint}};
+        {std::make_pair(vectorized_index_id_expr,
+                        Interval{0, description_.num_slices - 1})}};
 
     indices_map->Simplify();
   }
 
   if (updates_map) {
+    auto index_id = getAffineSymbolExpr(0, ctx);
     auto update_dim_loop = getAffineSymbolExpr(1, ctx);
     auto vector_id = getAffineSymbolExpr(2, ctx);
     auto num_elements_per_slice = Product(description_.slice_shape);
 
+    auto index_id_expr = slice_id * num_indices_per_warp_ + index_id;
     auto linear_slice_index =
         warp_id_in_slice * warp_size_ * vector_size_ +
         update_dim_loop * vector_size_ * warp_size_ * num_warps_per_slice_ +
@@ -652,7 +668,8 @@ void ScatterWithDistributedIndices::ComputeIndexing(
          IndexingMap::Variable{{0, vector_size_ - 1}, "vector_id"}},
         /*rt_vars=*/{},
         std::vector<std::pair<AffineExpr, Interval>>{
-            index_id_constraint,
+            std::make_pair(index_id_expr,
+                           Interval{0, description_.num_slices - 1}),
             std::make_pair(linear_slice_index,
                            Interval{0, num_elements_per_slice - 1})}};
 
```
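The indices map now splits a warp's index loop into `index_id_loop` (outer) and `index_vector_id` (inner lane). A quick self-check, with hypothetical sizes, that the new `vectorized_index_id_expr` enumerates exactly the contiguous ids the scalar `index_id_expr` covered before:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical sizes, chosen only for illustration.
  const int64_t num_indices_per_warp = 8;
  const int64_t indices_vector_size = 4;
  const int64_t slice_id = 3;  // which contiguous chunk of indices this warp owns

  // vectorized_index_id_expr = slice_id * num_indices_per_warp
  //                          + index_id_loop * indices_vector_size
  //                          + index_vector_id
  int64_t expected = slice_id * num_indices_per_warp;
  for (int64_t loop = 0; loop < num_indices_per_warp / indices_vector_size;
       ++loop) {
    for (int64_t lane = 0; lane < indices_vector_size; ++lane) {
      const int64_t id =
          slice_id * num_indices_per_warp + loop * indices_vector_size + lane;
      assert(id == expected++);  // same ids, traversed vector-wise
    }
  }
  return 0;
}
```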
```diff
@@ -695,11 +712,12 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
     ValueRange thread_and_block_ids, Value output_tensor) const {
   if (VLOG_IS_ON(5)) {
     llvm::errs() << "Settings for ScatterWithDistributedIndices: \n"
-                 << "vector_size_: " << vector_size_ << "\n"
-                 << "num_warps_: " << num_warps_ << "\n"
-                 << "num_blocks_: " << num_blocks_
-                 << "num_warps_per_slice_: " << num_warps_per_slice_ << "\n"
-                 << "num_indices_per_warp_: " << num_indices_per_warp_;
+                 << "vector_size: " << vector_size_ << "\n"
+                 << "num_warps: " << num_warps_ << "\n"
+                 << "num_blocks: " << num_blocks_ << "\n"
+                 << "num_warps_per_slice: " << num_warps_per_slice_ << "\n"
+                 << "num_indices_per_warp: " << num_indices_per_warp_ << "\n"
+                 << "indices_vector_size: " << indices_vector_size_ << "\n";
   }
   if (num_indices_per_warp_ == 1) {
     EmitNaiveImplementation(b, description_, helper, updates_map, indices_map,
@@ -709,12 +727,17 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
   MLIRContext* mlir_context = b.getContext();
 
   auto thread_id_to_update_id_map = IndexingMap(
-      AffineMap::get(6, 1, {updates_map.GetAffineMap().getResult(0)},
+      AffineMap::get(6, 2, {indices_map.GetAffineMap().getResult(0)},
                      mlir_context),
-      updates_map.GetDimVars(),
-      /*range_vars = */ {updates_map.GetRangeVars().front()},
-      /*rt vars = */ {});
-  IndexingMap slice_indexing = ConvertRangeVariableToDimension(updates_map, 0);
+      indices_map.GetDimVars(),
+      /*range_vars = */
+      {indices_map.GetRangeVars().begin(),
+       indices_map.GetRangeVars().begin() + 2},
+      /*rt vars = */ {}, indices_map.GetConstraints());
+
+  // Convert index_id_loop and index_vector_id to dimension variables.
+  IndexingMap slice_indexing =
+      ConvertRangeVariableToDimension(updates_map, {0});
 
   // Prepare loop initial values. Inits are packed as
   // [index_changed, is_inbounds, index_0, ..., accumulator].
@@ -740,7 +763,14 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl(
         Value iter_is_inbounds = iter_args_unpack[2].front();
         Value iter_acc = iter_args_unpack[3].front();
         Value iter_output = iter_args_unpack[4].front();
-        Value iter_slice_id = ivs.front();
+        CHECK_EQ(ivs.size(), 2);
+        Value index_loop_id = ivs.front();
+        Value index_vector_id = ivs.back();
+        Value iter_slice_id = nested_b.create<arith::AddIOp>(
+            nested_b.create<arith::MulIOp>(
+                index_loop_id,
+                nested_b.create<arith::ConstantIndexOp>(indices_vector_size_)),
+            index_vector_id);
 
         SmallVector<Value> offsets =
             PadWithZeros(trimmed_offsets, output_rank, nested_b);
```
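With two induction variables, the emitter rebuilds the scalar slice id with one `arith.muli` and one `arith.addi`. The scalar equivalent of that pair (a sketch, not the MLIR itself):

```cpp
#include <cstdint>

// Scalar equivalent of the emitted arith.muli / arith.addi pair: recover the
// flat per-warp index id from the outer loop iv and the vector-lane iv.
int64_t IterSliceId(int64_t index_loop_id, int64_t index_vector_id,
                    int64_t indices_vector_size) {
  return index_loop_id * indices_vector_size + index_vector_id;
}
```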
```diff
@@ -926,32 +956,44 @@ std::unique_ptr<ScatterFusion> CreateScatterFusion(
 
   int64_t max_active_warps =
       kNumWarpsPerBlock * analysis.device_info().core_count();
-  // For sorted scatter, we try to estimate the number of updates per warp by
-  // computing the ratio of the number of the given updates to the number of the
-  // possible valid indices. If we do not have multiple updates per warp, there
-  // is no reason to use this algorithm.
   // TODO(b/385081952): Investigate why bf16 and f64 leads to incorrect results.
-  if (description.scatter->indices_are_sorted() &&
-      description.elem_type != BF16 && num_slices > 2 * max_active_warps) {
-    int64_t num_indices_per_warp = CeilOfRatio(
-        num_slices, GetNumPossibleValidIndices(
-                        description.slice_shape, description.output_shape,
-                        description.index_vector_length));
-    int64_t num_warps_per_slice = 1;
-    if (num_indices_per_warp > 2 &&
-        num_active_threads_per_warp > warp_size / 2) {
-      return std::make_unique<ScatterWithDistributedIndices>(
-          analysis, description, vector_size, num_warps_per_slice,
-          num_indices_per_warp);
-    }
-  }
   // If we have enough data, we assign each warp to process a single
   // slice.
   if (num_slices > max_active_warps &&
       num_active_threads_per_warp > warp_size / 2) {
+    int64_t num_indices_per_warp = 1;
+    int64_t indices_vector_size = 1;
+    int64_t num_warps_per_slice = 1;
+    // For sorted scatter, we try to estimate the number of updates per warp
+    // by computing the ratio of the number of the given updates to the number
+    // of the possible valid indices. If we do not have multiple updates per
+    // warp, there is no reason to use this algorithm.
+    if (description.scatter->indices_are_sorted()) {
+      num_indices_per_warp = CeilOfRatio(
+          num_slices,
+          std::max(max_active_warps,
+                   GetNumPossibleValidIndices(
+                       description.slice_shape, description.output_shape,
+                       description.index_vector_length)));
+
+      // If the index_vector_length is 1, we can vectorize the indices read.
+      int64_t index_elem_type_bits = primitive_util::BitWidth(
+          description.scatter->scatter_indices()->shape().element_type());
+      int64_t max_vectorized_indices =
+          kMaxVectorizedBits / index_elem_type_bits;
+      if (description.index_vector_length == 1 &&
+          num_indices_per_warp > max_vectorized_indices) {
+        // Pad num_indices_per_warp to the next multiple of
+        // max_vectorized_indices.
+        num_indices_per_warp =
+            CeilOfRatio(num_indices_per_warp, max_vectorized_indices) *
+            max_vectorized_indices;
+        indices_vector_size = max_vectorized_indices;
+      }
+    }
     return std::make_unique<ScatterWithDistributedIndices>(
-        analysis, description, vector_size,
-        /*num_warps_per_slice=*/1, /*num_indices_per_warp=*/1);
+        analysis, description, vector_size, num_warps_per_slice,
+        num_indices_per_warp, indices_vector_size);
   }
   // Otherwise, we distribute the linearized updates tensor.
   vector_size =
```
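A worked example of the new sizing logic, assuming 32-bit scatter indices and taking `kMaxVectorizedBits` = 128 (the constant's actual value lives elsewhere in this file; 128 is an assumption here): `max_vectorized_indices` = 4, so an estimate of 10 indices per warp is padded up to 12 and the indices vector size becomes 4.

```cpp
#include <cstdint>
#include <cstdio>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kMaxVectorizedBits = 128;   // assumed value, for illustration
  const int64_t index_elem_type_bits = 32;  // e.g. S32 scatter indices
  const int64_t max_vectorized_indices =
      kMaxVectorizedBits / index_elem_type_bits;  // 4

  int64_t num_indices_per_warp = 10;  // from the sorted-scatter estimate
  int64_t indices_vector_size = 1;
  if (num_indices_per_warp > max_vectorized_indices) {
    // Pad to the next multiple of max_vectorized_indices: 10 -> 12.
    num_indices_per_warp =
        CeilOfRatio(num_indices_per_warp, max_vectorized_indices) *
        max_vectorized_indices;
    indices_vector_size = max_vectorized_indices;
  }
  std::printf("num_indices_per_warp=%lld indices_vector_size=%lld\n",
              static_cast<long long>(num_indices_per_warp),
              static_cast<long long>(indices_vector_size));
}
```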

xla/backends/gpu/codegen/emitters/scatter.h

Lines changed: 12 additions & 6 deletions
```diff
@@ -147,26 +147,29 @@ class ScatterWithDistributedUpdates : public ScatterFusion {
   %acc = vector<num_iter x vector_size>
 
   // #indices_map
-  %updated_accumulator, %updated_out = for %i = 0 to %num_indices_per_warp_ {
-    %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %i))
+  %updated_accumulator, %updated_out
+     = for %i = 0 to %num_indices_per_warp_ step %indices_vector_size_ {
+    for %j = 0 to %indices_vector_size_ step 1 {
+      %index = %i * %indices_vector_size_ + %j + %index_start(bl_x, th_x)
+      %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %index))
       %indices_changed = EmitInequalityCheck(%new_indices, %indices)
-      if (%indices_changed && %i != 0) {
+      if (%indices_changed && %index != 0) {
         %output_tensor = WriteAccumulatorToOutput(%current_acc, %current_out);
       }
       if (%indices_changed) {
         %inbounds = EmitBoundsCheck(%new_indices, %slice_shape, %output_shape)
       }
       if (%inbounds) {
         if (%indices_changed) {
-          // updates_map(%i)
+          // updates_map(%index)
           for %j = 0 to %num_slice_iterations_per_warp step 1 {
             for %k = 0 to %vector_size step 1 {
               %update_elem = GetUpdateElement
               %acc = %update_elem
             }
           }
         } else {
-          // updates_map(%i)
+          // updates_map(%index)
           for %j = 0 to %num_slice_iterations_per_warp step 1 {
             for %k = 0 to %vector_size step 1 {
               %update_elem = GetUpdateElement
@@ -184,7 +187,8 @@ class ScatterWithDistributedIndices : public ScatterFusion {
                                 const ScatterDescription& description,
                                 int64_t vector_size,
                                 int64_t num_warps_per_slice,
-                                int64_t num_indices_per_warp);
+                                int64_t num_indices_per_warp,
+                                int64_t indices_vector_size);
 
  protected:
   void ComputeIndexing(mlir::MLIRContext* ctx, IndexingMap* updates_map,
@@ -206,6 +210,8 @@ class ScatterWithDistributedIndices : public ScatterFusion {
   // The number of indices that every warp iterates over. This is a useful
   // setting, if we know that the indices tensor is sorted.
   int64_t num_indices_per_warp_;
+  // Vector size for the indices operand.
+  int64_t indices_vector_size_;
 };
 
 std::unique_ptr<ScatterFusion> CreateScatterFusion(
```
