Skip to content

Commit 8a2f201

Browse files
authored
[GPU] Add vector distribution pattern for map_scatter (#21124)
Adds a vector distribution pattern for `iree_linalg_ext.map_scatter`. The implementation is similar to that of vector.transfer_write without masking, and the main difference is in how the distributed offsets are handled by the distributed op. This distribution will be used after the map_scatter op is vectorized, but before it is decomposed. This keeps the distribution pattern simple, because only the input vector needs to be distributed, and the index mapping to the distributed space is very simple. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent 643382b commit 8a2f201

File tree

2 files changed

+178
-12
lines changed

2 files changed

+178
-12
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp

Lines changed: 115 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
1111
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1212
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
13+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1314
#include "iree/compiler/Utils/Permutation.h"
1415
#include "llvm/ADT/ArrayRef.h"
1516
#include "llvm/ADT/SmallVector.h"
@@ -282,6 +283,17 @@ static VectorValue projectVector(RewriterBase &rewriter, Location loc,
282283
return cast<VectorValue>(sliced.getResult());
283284
}
284285

286+
/// Extracts the slice of `src` at `offsets` and returns it as a vector value.
/// When the extraction yields a scalar, the result is promoted to a rank-0
/// vector so callers can uniformly rely on receiving a VectorValue.
static VectorValue extractSliceAsVector(RewriterBase &rewriter, Location loc,
                                        Value src, ArrayRef<int64_t> offsets) {
  Value extracted = rewriter.create<vector::ExtractOp>(loc, src, offsets);
  if (isa<VectorType>(extracted.getType())) {
    return cast<VectorValue>(extracted);
  }
  // The extraction produced a scalar; broadcast it into a 0-d vector.
  auto zeroDVecType = VectorType::get({}, getElementTypeOrSelf(extracted));
  Value promoted =
      rewriter.create<vector::BroadcastOp>(loc, zeroDVecType, extracted);
  return cast<VectorValue>(promoted);
}
296+
285297
namespace {
286298

287299
/// Pattern to distribute `vector.transfer_read` ops with nested layouts.
@@ -476,16 +488,9 @@ struct DistributeTransferWrite final
476488
// dimensions are either unrolled or distributed such that this is a
477489
// contiguous slice.
478490
ArrayRef<int64_t> offsetArray(offsets);
479-
Value slicedVector = rewriter.create<vector::ExtractOp>(
480-
writeOp.getLoc(), distributedVector,
481-
offsetArray.take_front(rank * 2));
482-
// Promote the slicedVector to 0-d vector if it is a scalar.
483-
if (!isa<VectorType>(slicedVector.getType())) {
484-
auto promotedType =
485-
VectorType::get({}, getElementTypeOrSelf(slicedVector));
486-
slicedVector = rewriter.create<vector::BroadcastOp>(
487-
writeOp.getLoc(), promotedType, slicedVector);
488-
}
491+
VectorValue slicedVector =
492+
extractSliceAsVector(rewriter, writeOp.getLoc(), distributedVector,
493+
offsetArray.take_front(rank * 2));
489494

490495
VectorValue slicedMask = nullptr;
491496
if (mask) {
@@ -676,6 +681,104 @@ struct DistributeTransferGather final
676681
int64_t subgroupSize;
677682
};
678683

684+
/// Pattern to distribute `iree_linalg_ext.map_scatter` ops with nested layouts.
/// Only the input is distributed, since the output is never a vector. The
/// distribution of the input is similar to that of a vector.transfer_write.
struct DistributeMapScatter final
    : OpDistributionPattern<IREE::LinalgExt::MapScatterOp> {
  using OpDistributionPattern::OpDistributionPattern;

  DistributeMapScatter(MLIRContext *context, Value threadId,
                       int64_t subgroupSize)
      : OpDistributionPattern(context), threadId(threadId),
        subgroupSize(subgroupSize) {}

  LogicalResult matchAndRewrite(IREE::LinalgExt::MapScatterOp mapScatterOp,
                                DistributionSignature &signature,
                                PatternRewriter &rewriter) const override {
    // Only vector inputs can be distributed; scalar/tensor inputs don't match.
    auto input = dyn_cast<VectorValue>(mapScatterOp.getInput());
    if (!input) {
      return rewriter.notifyMatchFailure(mapScatterOp, "input is not a vector");
    }
    // The input's layout must be a NestedLayoutAttr for this distribution
    // scheme (subgroup/batch/outer/thread/element tiling) to apply.
    NestedLayoutAttr vectorLayout =
        dyn_cast<NestedLayoutAttr>(signature[input]);
    if (!vectorLayout) {
      return rewriter.notifyMatchFailure(mapScatterOp,
                                         "non-nested map_scatter layout");
    }
    // The output is scattered into directly, so it must be a memref.
    if (!isa<MemRefType>(mapScatterOp.getOutput().getType())) {
      return rewriter.notifyMatchFailure(mapScatterOp,
                                         "distribution expects memrefs");
    }
    // Compute the per-subgroup and per-thread index values implied by the
    // layout's strides; fails when the tiles have overlapping strides.
    SmallVector<Value> warpIndices, threadIndices;
    if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
                                            vectorLayout, warpIndices,
                                            threadIndices))) {
      return rewriter.notifyMatchFailure(
          mapScatterOp, "warp or thread tiles have overlapping strides");
    }

    // The thread-local view of the input vector.
    Value distributedVector = getDistributed(rewriter, input, vectorLayout);

    Location loc = mapScatterOp.getLoc();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
    SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
    // Unroll over every element-vector tile of the distributed shape, emitting
    // one map_scatter per tile (mirrors DistributeTransferWrite's unrolling).
    for (auto [idx, offsets] :
         llvm::enumerate(StaticTileOffsetRange(distShape, tileShape))) {
      // Extract the "element vector" from the inner most dimensions. All outer
      // dimensions are either unrolled or distributed such that this is a
      // contiguous slice.
      ArrayRef<int64_t> offsetArray(offsets);
      VectorValue distributedInput = extractSliceAsVector(
          rewriter, loc, distributedVector,
          offsetArray.take_front(vectorLayout.getRank() * 2));

      // Clone the map_scatter op with the "element vector" as the input, and
      // adjust the transformation region to account for the distributed
      // offsets.
      AffineMap permutationMap =
          rewriter.getMultiDimIdentityMap(input.getType().getRank());
      // Base indices are all zero: the distributed offsets below fully encode
      // this tile's position in the undistributed vector.
      SmallVector<Value> indices(input.getType().getRank(), zero);
      SmallVector<Value> distributedOffsets =
          getTransferIndicesFromNestedLayout(rewriter, indices, offsets,
                                             vectorLayout, permutationMap,
                                             warpIndices, threadIndices);
      IREE::LinalgExt::MapScatterOp distributedMapScatter =
          clone(rewriter, mapScatterOp, mapScatterOp.getResultTypes(),
                {distributedInput, mapScatterOp.getOutput()});
      // The slice may be rank-reduced relative to the original input; the
      // leading `rankDiff` dims were unit-sized and dropped by the extract.
      int64_t sliceRank = distributedInput.getType().getRank();
      int64_t rankDiff = input.getType().getRank() - sliceRank;
      // Add the distributed offsets in the map_scatter transformation body.
      auto transformationBuilder = [&](ArrayRef<BlockArgument> newIndices) {
        SmallVector<Value> replacementIndices(distributedOffsets);
        for (auto [i, replacementIdx] : llvm::enumerate(replacementIndices)) {
          // Rank-reduced dimensions can be directly replaced by the distributed
          // index, since their size is 1 in the new map_scatter input.
          if (i < rankDiff) {
            continue;
          }
          // Otherwise, the dimension is a contiguous element dimension, so
          // the mapping is achieved by adding the corresponding block argument
          // to the sliced index.
          BlockArgument newTransformationIdx = newIndices[i - rankDiff];
          replacementIdx = rewriter.create<arith::AddIOp>(
              loc, newTransformationIdx, replacementIdx);
        }
        return replacementIndices;
      };
      // Prepend the index remapping so the cloned op's original transformation
      // region now operates on undistributed-space indices.
      distributedMapScatter.insertTransformationAtStart(
          rewriter, transformationBuilder, sliceRank);
    }

    // All tiles have been scattered; the original op is no longer needed.
    rewriter.eraseOp(mapScatterOp);
    return success();
  }

  Value threadId;
  int64_t subgroupSize;
};
781+
679782
static VectorValue broadcastToShape(RewriterBase &rewriter, Value source,
680783
ArrayRef<int64_t> shape,
681784
ArrayRef<bool> broadcastedDims) {
@@ -2030,8 +2133,8 @@ void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
20302133
int64_t subgroupSize,
20312134
int64_t maxBitsPerShuffle) {
20322135
patterns.add<DistributeTransferRead, DistributeTransferWrite,
2033-
DistributeTransferGather>(patterns.getContext(), threadId,
2034-
subgroupSize);
2136+
DistributeTransferGather, DistributeMapScatter>(
2137+
patterns.getContext(), threadId, subgroupSize);
20352138
patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
20362139
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
20372140
maxBitsPerShuffle);

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,3 +1346,66 @@ builtin.module attributes { transform.with_named_sequence } {
13461346

13471347
// CHECK-LABEL: @paged_transfer_gather_multi_index
13481348
// CHECK-COUNT-4: vector_ext.transfer_gather
1349+
1350+
// -----
1351+
1352+
// Row-major nested layout: one subgroup, 2x2 batches of 1x8 element tiles
// distributed across 8 threads along the rows (16x16 total).
#layout_row_major = #iree_vector_ext.nested_layout<
  subgroup_tile = [1, 1],
  batch_tile = [2, 2],
  outer_tile = [1, 1],
  thread_tile = [8, 1],
  element_tile = [1, 8],

  subgroup_strides = [1, 1],
  thread_strides = [1, 1]
>

// map_scatter with an identity index transformation (always-true mask);
// distribution should unroll it into one map_scatter per element tile.
func.func @distribute_map_scatter_row_major(%root: vector<16x16xf16>, %output: memref<64x64xf16>) {
  %rootl = iree_vector_ext.to_layout %root to layout(#layout_row_major) : vector<16x16xf16>
  iree_linalg_ext.map_scatter %rootl into %output {
  ^bb0(%idx0: index, %idx1: index):
    %mask = arith.constant true
    iree_linalg_ext.yield %idx0, %idx1, %mask : index, index, i1
  } : vector<16x16xf16> into memref<64x64xf16>
  func.return
}

// Transform script driving the GPU vector-distribution test pass on the
// enclosing function.
builtin.module attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
    transform.yield
  }
}
1380+
1381+
// CHECK-LABEL: @distribute_map_scatter_row_major
1382+
// CHECK-DAG: %[[IDX:.+]] = gpu.thread_id x
1383+
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
1384+
// CHECK-DAG: %[[LANEX:.+]]:2 = affine.delinearize_index %[[IDX]] into (8)
1385+
// CHECK-DAG: %[[SLICE0:.+]] = vector.extract %{{.*}}[0, 0, 0, 0]
1386+
// CHECK: iree_linalg_ext.map_scatter %[[SLICE0]]
1387+
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
1388+
// CHECK: %[[DISTRIBUTED_IDX0:.+]] = arith.addi %[[IDX0]], %[[LANEX]]#1
1389+
// CHECK: iree_linalg_ext.yield %[[DISTRIBUTED_IDX0]], %[[IDX1]]
1390+
// CHECK: : vector<1x8xf16> into memref<64x64xf16>
1391+
// CHECK: %[[SLICE1:.+]] = vector.extract %{{.*}}[0, 1, 0, 0]
1392+
// CHECK: iree_linalg_ext.map_scatter %[[SLICE1]]
1393+
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
1394+
// CHECK-DAG: %[[DISTRIBUTED_IDX0:.+]] = arith.addi %[[IDX0]], %[[LANEX]]#1
1395+
// CHECK-DAG: %[[DISTRIBUTED_IDX1:.+]] = arith.addi %[[IDX1]], %[[C8]]
1396+
// CHECK: iree_linalg_ext.yield %[[DISTRIBUTED_IDX0]], %[[DISTRIBUTED_IDX1]]
1397+
// CHECK: : vector<1x8xf16> into memref<64x64xf16>
1398+
// CHECK-DAG: %[[LANEX_PLUS_VECDIMX:.+]] = affine.linearize_index disjoint [%c1, %[[LANEX]]#1] by (2, 8)
1399+
// CHECK-DAG: %[[SLICE2:.+]] = vector.extract %{{.*}}[1, 0, 0, 0]
1400+
// CHECK: iree_linalg_ext.map_scatter %[[SLICE2]]
1401+
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
1402+
// CHECK: %[[DISTRIBUTED_IDX0:.+]] = arith.addi %[[IDX0]], %[[LANEX_PLUS_VECDIMX]]
1403+
// CHECK: iree_linalg_ext.yield %[[DISTRIBUTED_IDX0]], %[[IDX1]]
1404+
// CHECK: : vector<1x8xf16> into memref<64x64xf16>
1405+
// CHECK: %[[SLICE3:.+]] = vector.extract %{{.*}}[1, 1, 0, 0]
1406+
// CHECK: iree_linalg_ext.map_scatter %[[SLICE3]]
1407+
// CHECK: ^bb0(%[[IDX0:.+]]: index, %[[IDX1:.+]]: index):
1408+
// CHECK-DAG: %[[DISTRIBUTED_IDX0:.+]] = arith.addi %[[IDX0]], %[[LANEX_PLUS_VECDIMX]]
1409+
// CHECK-DAG: %[[DISTRIBUTED_IDX1:.+]] = arith.addi %[[IDX1]], %[[C8]]
1410+
// CHECK: iree_linalg_ext.yield %[[DISTRIBUTED_IDX0]], %[[DISTRIBUTED_IDX1]]
1411+
// CHECK: : vector<1x8xf16> into memref<64x64xf16>

0 commit comments

Comments
 (0)