Skip to content

Commit 3774bea

Browse files
authored
[VectorExt] Add canonicalizations for iree_vector_ext.transfer_gather (iree-org#20454)
This patch adds canonicalizations to fold index_vecs into the transfer_gather operation.
1 parent d96dd22 commit 3774bea

File tree

4 files changed

+335
-4
lines changed

4 files changed

+335
-4
lines changed

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def IREEVectorExt_Dialect : Dialect {
3232
let extraClassDeclaration = [{
3333
void registerAttributes();
3434
}];
35+
let dependentDialects = [
36+
"affine::AffineDialect",
37+
"vector::VectorDialect"
38+
];
3539
}
3640

3741
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include <numeric>
99
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
1010
#include "llvm/ADT/TypeSwitch.h"
11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
12+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1113
#include "mlir/IR/DialectImplementation.h"
1214

1315
using namespace mlir;

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp

Lines changed: 214 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
88
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Utils/IndexingUtils.h"
911
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1012

1113
using namespace mlir;
@@ -38,9 +40,9 @@ OpFoldResult ToSIMTOp::fold(FoldAdaptor) {
3840
return {};
3941
}
4042

41-
//
43+
//===----------------------------------------------------------------------===//
4244
// TransferGatherOp
43-
//
45+
//===----------------------------------------------------------------------===//
4446

4547
Speculation::Speculatability TransferGatherOp::getSpeculatability() {
4648
if (isa<RankedTensorType>(getSource().getType())) {
@@ -438,13 +440,221 @@ ParseResult TransferGatherOp::parse(OpAsmParser &parser,
438440
return parser.addTypeToList(vectorType, result.types);
439441
}
440442

441-
void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
442-
MLIRContext *ctx) {}
443+
/// Returns the rank of `type` when it is a VectorType; scalars (and any
/// non-vector type) are treated as rank 0.
static int64_t getVectorRank(Type type) {
  auto vecType = dyn_cast<VectorType>(type);
  return vecType ? vecType.getRank() : 0;
}
449+
450+
/// Result of folding a single index vector of a transfer_gather:
/// - `indexVec`: the replacement index vector; a null Value means the
///   dimension is no longer gathered and the vector is dropped.
/// - `indexMap`: the indexed map matching `indexVec` (ignored when
///   `indexVec` is null).
/// - `changed`: whether the folder changed anything.
struct IndexVecFoldResult {
  Value indexVec;
  AffineMap indexMap;
  bool changed;
};
455+
456+
/// Rewrites `gatherOp` in place by applying `indexVecFolder` to every index
/// vector operand. The callback receives the index vector, its indexed map,
/// and the memory dimension it indexes; it returns the replacement
/// vector/map and whether anything changed. A null replacement Value drops
/// the index vector and marks the dimension as non-indexed. Returns the
/// op's result when a change was made, a null Value otherwise.
static Value foldTransferGatherIndexVecs(
    TransferGatherOp gatherOp,
    function_ref<IndexVecFoldResult(Value, AffineMap, int64_t)>
        indexVecFolder) {
  bool changed = false;
  SmallVector<Value> newIndexVecs;
  SmallVector<AffineMap> newIndexedMaps;
  SmallVector<bool> indexed(gatherOp.getIndexed().getAsValueRange<BoolAttr>());
  // Index vectors are stored densely (only for indexed dims); `currIndexVec`
  // tracks which one belongs to the current dimension `i`.
  int64_t currIndexVec = 0;
  for (auto i : llvm::seq<int64_t>(gatherOp.getIndices().size())) {
    if (!indexed[i]) {
      continue;
    }
    Value operand = gatherOp.getIndexVecs()[currIndexVec];
    AffineMap map = gatherOp.getIndexedMapsArray()[currIndexVec];
    ++currIndexVec;

    auto [indexVec, indexMap, vecChanged] = indexVecFolder(operand, map, i);
    changed |= vecChanged;

    if (indexVec) {
      newIndexVecs.push_back(indexVec);
      newIndexedMaps.push_back(indexMap);
      indexed[i] = true;
    } else {
      // The folder dropped the index vector: this dimension becomes a plain
      // (non-gathered) dimension.
      indexed[i] = false;
    }
  }

  if (!changed) {
    return Value();
  }

  OpBuilder b(gatherOp);

  // Rebuild the complete operand list and matching segment sizes, since the
  // number of index vectors may have changed.
  SmallVector<Value> operands;
  SmallVector<int32_t> operandSegmentSizes;

  // Source.
  operands.push_back(gatherOp.getSource());
  operandSegmentSizes.push_back(1);
  // Indices.
  SmallVector<Value> indices = gatherOp.getIndices();
  operands.append(indices);
  operandSegmentSizes.push_back(indices.size());
  // IndexVecs.
  operands.append(newIndexVecs);
  operandSegmentSizes.push_back(newIndexVecs.size());
  // Padding.
  operands.push_back(gatherOp.getPadding());
  operandSegmentSizes.push_back(1);
  // Mask.
  if (gatherOp.getMask()) {
    operands.push_back(gatherOp.getMask());
    operandSegmentSizes.push_back(1);
  } else {
    operandSegmentSizes.push_back(0);
  }

  // NOTE(review): this mutates the op directly rather than through a
  // rewriter; callers invoking this from a rewrite pattern should ensure the
  // modification is visible to the rewriter (e.g. modifyOpInPlace) -- confirm.
  gatherOp.setIndexedMapsAttr(b.getAffineMapArrayAttr(newIndexedMaps));
  gatherOp->setOperands(operands);
  gatherOp.setIndexedAttr(b.getBoolArrayAttr(indexed));
  gatherOp.getProperties().setOperandSegmentSizes(operandSegmentSizes);

  return gatherOp.getResult();
}
522+
523+
/// Folds vector.broadcast producers of index vectors into `gatherOp`: the
/// broadcast is dropped and the indexed map is restricted to the dimensions
/// of the broadcast source.
static Value foldTransferGatherFromBroadcast(TransferGatherOp gatherOp) {
  return foldTransferGatherIndexVecs(
      gatherOp,
      [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
        auto broadcast = operand.getDefiningOp<vector::BroadcastOp>();
        if (!broadcast) {
          return {operand, map, false};
        }

        // Broadcast prepends dimensions, so the source corresponds to the
        // trailing `sourceRank` results of the map.
        int64_t sourceRank = getVectorRank(broadcast.getSourceType());
        int64_t operandRank = getVectorRank(broadcast.getResultVectorType());
        AffineMap newMap =
            map.getSliceMap(operandRank - sourceRank, sourceRank);
        // NOTE(review): a broadcast from a scalar gives sourceRank == 0 and a
        // scalar replacement value here -- verify scalar sources cannot reach
        // this fold, or handle them explicitly.
        return {broadcast.getSource(), newMap, true};
      });
}
539+
540+
/// Folds vector.transpose producers of index vectors into `gatherOp` by
/// composing the inverse permutation into the indexed map.
static Value foldTransferGatherFromTranspose(TransferGatherOp gatherOp) {
  return foldTransferGatherIndexVecs(
      gatherOp,
      [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
        auto transpose = operand.getDefiningOp<vector::TransposeOp>();
        if (!transpose) {
          return {operand, map, false};
        }

        // Reading the transposed vector through the inverted permutation is
        // equivalent to reading the transpose's input directly.
        AffineMap newMap =
            AffineMap::getPermutationMap(
                invertPermutationVector(transpose.getPermutation()),
                transpose.getContext())
                .compose(map);
        return {transpose.getVector(), newMap, true};
      });
}
557+
558+
/// Drops index vectors produced by vector.step when they index their own
/// memory dimension: such a dimension is simply a contiguous load and needs
/// no gather index.
static Value foldTransferGatherFromStep(TransferGatherOp gatherOp) {
  return foldTransferGatherIndexVecs(
      gatherOp,
      [](Value operand, AffineMap map, int64_t index) -> IndexVecFoldResult {
        auto step = operand.getDefiningOp<vector::StepOp>();
        if (!step) {
          return {operand, map, false};
        }

        // vector.step produces a 1-D vector, so its indexed map must have
        // exactly one result.
        assert(map.getNumResults() == 1);
        int64_t resultDim = cast<AffineDimExpr>(map.getResult(0)).getPosition();

        // If the map is indexing along the memory dimension, and the vector is
        // contiguous, this is a contiguous load on this dimension.
        if (resultDim == index) {
          return {Value(), AffineMap(), true};
        }

        return {operand, map, false};
      });
}
443579

444580
OpFoldResult TransferGatherOp::fold(FoldAdaptor adaptor) {
  // Try each index-vector folding in turn; the first one that applies wins.
  using FolderFn = Value (*)(TransferGatherOp);
  constexpr FolderFn kFolders[] = {&foldTransferGatherFromBroadcast,
                                   &foldTransferGatherFromTranspose,
                                   &foldTransferGatherFromStep};
  for (FolderFn folder : kFolders) {
    if (Value res = folder(*this)) {
      return res;
    }
  }
  return OpFoldResult();
}
447592

593+
/// Rewrites an index vector containing a single element into a scalar addend
/// on the corresponding base index, turning the dimension into a
/// non-gathered one.
struct FoldSingleElementIndexVec final : OpRewritePattern<TransferGatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferGatherOp xferOp,
                                PatternRewriter &rewriter) const override {

    auto indexVecFolder = [&](Value indexVec, AffineMap map,
                              int64_t index) -> IndexVecFoldResult {
      auto vectorTy = cast<VectorType>(indexVec.getType());
      if (vectorTy.getNumElements() != 1) {
        return {indexVec, map, false};
      }

      // Extract the scalar and add it to the
      // corresponding base.
      OpOperand &base = xferOp.getIndicesMutable()[index];
      Value extracted = rewriter.create<vector::ExtractOp>(
          xferOp.getLoc(), indexVec,
          SmallVector<int64_t>(vectorTy.getRank(), 0));
      AffineExpr d0, d1;
      bindDims(xferOp.getContext(), d0, d1);
      Value newIndex = affine::makeComposedAffineApply(
                           rewriter, xferOp.getLoc(), d0 + d1,
                           ArrayRef<OpFoldResult>{base.get(), extracted})
                           .getResult();
      // NOTE(review): the operand is mutated directly, and the later
      // foldTransferGatherIndexVecs call also updates the op outside
      // rewriter.modifyOpInPlace -- confirm rewriter listeners tolerate this.
      base.set(newIndex);

      // Null Value: drop this index vector entirely.
      return {Value(), AffineMap(), true};
    };

    Value newVal = foldTransferGatherIndexVecs(xferOp, indexVecFolder);

    // A null result means nothing changed, so the pattern did not apply.
    if (!newVal) {
      return failure();
    }

    return success();
  }
};
632+
633+
struct FoldContigousGatherToTransferRead final
634+
: OpRewritePattern<TransferGatherOp> {
635+
using OpRewritePattern::OpRewritePattern;
636+
637+
LogicalResult matchAndRewrite(TransferGatherOp xferOp,
638+
PatternRewriter &rewriter) const override {
639+
if (!xferOp.getIndexVecs().empty()) {
640+
return failure();
641+
}
642+
643+
// Canonicalize to vector.transfer_read.
644+
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
645+
xferOp, xferOp.getVectorType(), xferOp.getSource(), xferOp.getIndices(),
646+
xferOp.getPermutationMap(), xferOp.getPadding(), xferOp.getMask(),
647+
xferOp.getInBounds());
648+
return success();
649+
};
650+
};
651+
652+
/// Registers the transfer_gather canonicalization patterns.
void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *ctx) {
  results.add<FoldSingleElementIndexVec>(ctx);
  results.add<FoldContigousGatherToTransferRead>(ctx);
}
657+
448658
// clang-format off
449659
#define GET_OP_CLASSES
450660
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/test/canonicalize.mlir

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,118 @@ func.func @to_simd_to_simt_multi_use(%simt: vector<4x4x4xf32>) -> (vector<4x4x4x
4747
}
4848

4949
// -----
50+
51+
// Checks that a vector.broadcast feeding an index_vec is folded away: the
// broadcast is dropped and the indexed map is reduced to the source dims.
func.func @transfer_gather_fold_broadcast(%indices: vector<64xindex>,
                                          %source: tensor<4096x64xf16>)
    -> vector<64x32xf16> {

  %cst0 = arith.constant 0.0 : f16
  %c0 = arith.constant 0 : index

  %broadcasted = vector.broadcast %indices : vector<64xindex> to vector<32x64xindex>

  %out = iree_vector_ext.transfer_gather %source[%c0, %c0]
    [None, %broadcasted: vector<32x64xindex>], %cst0
    { indexed_maps = [affine_map<(d0, d1) -> (d1, d0)>]}
    : tensor<4096x64xf16>, vector<64x32xf16>

  return %out : vector<64x32xf16>
}

// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: @transfer_gather_fold_broadcast
// CHECK: transfer_gather
// CHECK-SAME: indexed_maps = [#[[$MAP]]]
72+
73+
// -----
74+
75+
// Checks that a vector.transpose feeding an index_vec is folded away by
// composing the inverse permutation into the indexed map.
func.func @transfer_gather_fold_transpose(%indices: vector<64x32xindex>,
                                          %source: tensor<4096x64xf16>)
    -> vector<64x32xf16> {

  %cst0 = arith.constant 0.0 : f16
  %c0 = arith.constant 0 : index

  %transposed = vector.transpose %indices, [1, 0] : vector<64x32xindex> to vector<32x64xindex>

  %out = iree_vector_ext.transfer_gather %source[%c0, %c0]
    [None, %transposed: vector<32x64xindex>], %cst0
    {indexed_maps = [affine_map<(d0, d1) -> (d1, d0)>]}
    : tensor<4096x64xf16>, vector<64x32xf16>

  return %out : vector<64x32xf16>
}

// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @transfer_gather_fold_transpose
// CHECK: transfer_gather
// CHECK-SAME: indexed_maps = [#[[$MAP]]]
96+
97+
// -----
98+
99+
// Checks that a vector.step index_vec indexing its own memory dimension is
// dropped (the dimension becomes a contiguous, non-gathered load).
func.func @transfer_gather_fold_step(%indices: vector<64x32xindex>,
                                     %source: tensor<4096x64xf16>)
    -> vector<64x32xf16> {

  %cst0 = arith.constant 0.0 : f16
  %c0 = arith.constant 0 : index

  %step = vector.step : vector<64xindex>

  %out = iree_vector_ext.transfer_gather %source[%c0, %c0]
    [%step : vector<64xindex>, %indices: vector<64x32xindex>], %cst0
    {indexed_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)> ]}
    : tensor<4096x64xf16>, vector<64x32xf16>

  return %out : vector<64x32xf16>
}

// CHECK-LABEL: @transfer_gather_fold_step
// CHECK-SAME: %[[ARG1:.*]]: vector<64x32xindex>
// CHECK: transfer_gather
// CHECK-SAME: [None, %[[ARG1]]
120+
121+
// -----
122+
123+
// Checks that a single-element index_vec is folded into the base index: the
// scalar is extracted and added to the base, and the index_vec is dropped.
func.func @transfer_gather_fold_single_element(%scalar: vector<1xindex>,
                                               %indices: vector<64x1xindex>,
                                               %source: tensor<4096x64xf16>)
    -> vector<64x1xf16> {

  %cst0 = arith.constant 0.0 : f16
  %c0 = arith.constant 0 : index

  %out = iree_vector_ext.transfer_gather %source[%c0, %c0]
    [%scalar : vector<1xindex>, %indices: vector<64x1xindex>], %cst0
    {indexed_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)> ]}
    : tensor<4096x64xf16>, vector<64x1xf16>

  return %out : vector<64x1xf16>
}

// CHECK-LABEL: transfer_gather_fold_single_element
// CHECK-SAME: %{{.*}}: vector<1xindex>, %[[ARG1:.*]]: vector<64x1xindex>
// CHECK: transfer_gather
// CHECK-SAME: [None, %[[ARG1]]
143+
144+
// -----
145+
146+
// Checks that a transfer_gather with no index_vecs canonicalizes to a plain
// vector.transfer_read. (Typo "contigious" fixed in the test name and its
// CHECK-LABEL together.)
func.func @transfer_gather_fold_contiguous_load(%scalar: vector<64x1xindex>,
                                                %indices: vector<64x1xindex>,
                                                %source: tensor<4096x64xf16>)
    -> vector<64x1xf16> {

  %cst0 = arith.constant 0.0 : f16
  %c0 = arith.constant 0 : index

  %out = iree_vector_ext.transfer_gather %source[%c0, %c0]
    [None, None], %cst0 {indexed_maps = []} : tensor<4096x64xf16>, vector<64x1xf16>

  return %out : vector<64x1xf16>
}

// CHECK-LABEL: @transfer_gather_fold_contiguous_load
// CHECK: vector.transfer_read
// CHECK-NOT: transfer_gather
163+
164+
// -----

0 commit comments

Comments
 (0)