Skip to content

Commit 102fcf0

Browse files
committed
Add IndexBitWidth option to vector-to-llvm pass
Change-Id: I1ad6f77183f1f1faf25e935131de4ef3a4334150
1 parent af64f0a commit 102fcf0

File tree

6 files changed

+417
-38
lines changed

6 files changed

+417
-38
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14141414
"vector::VectorTransformsOptions",
14151415
/*default=*/"vector::VectorTransformsOptions()",
14161416
"Options to lower some operations like contractions and transposes.">,
1417+
Option<"indexBitwidth", "index-bitwidth", "unsigned",
1418+
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
1419+
"Bitwidth of the index type, 0 to use size of machine word">,
14171420
];
14181421
}
14191422

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,9 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
4949
int64_t pos) {
5050
assert(rank > 0 && "0-D vector corner case should have been handled already");
5151
if (rank == 1) {
52-
auto idxType = rewriter.getIndexType();
52+
auto idxType = typeConverter.convertType(rewriter.getIndexType());
5353
auto constant = rewriter.create<LLVM::ConstantOp>(
54-
loc, typeConverter.convertType(idxType),
55-
rewriter.getIntegerAttr(idxType, pos));
54+
loc, idxType, rewriter.getIntegerAttr(idxType, pos));
5655
return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
5756
constant);
5857
}
@@ -64,10 +63,9 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
6463
const LLVMTypeConverter &typeConverter, Location loc,
6564
Value val, Type llvmType, int64_t rank, int64_t pos) {
6665
if (rank <= 1) {
67-
auto idxType = rewriter.getIndexType();
66+
auto idxType = typeConverter.convertType(rewriter.getIndexType());
6867
auto constant = rewriter.create<LLVM::ConstantOp>(
69-
loc, typeConverter.convertType(idxType),
70-
rewriter.getIntegerAttr(idxType, pos));
68+
loc, idxType, rewriter.getIntegerAttr(idxType, pos));
7169
return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
7270
constant);
7371
}
@@ -1064,10 +1062,9 @@ class VectorExtractElementOpConversion
10641062

10651063
if (vectorType.getRank() == 0) {
10661064
Location loc = extractEltOp.getLoc();
1067-
auto idxType = rewriter.getIndexType();
1065+
auto idxType = typeConverter->convertType(rewriter.getIndexType());
10681066
auto zero = rewriter.create<LLVM::ConstantOp>(
1069-
loc, typeConverter->convertType(idxType),
1070-
rewriter.getIntegerAttr(idxType, 0));
1067+
loc, idxType, rewriter.getIntegerAttr(idxType, 0));
10711068
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
10721069
extractEltOp, llvmType, adaptor.getVector(), zero);
10731070
return success();
@@ -1198,10 +1195,9 @@ class VectorInsertElementOpConversion
11981195

11991196
if (vectorType.getRank() == 0) {
12001197
Location loc = insertEltOp.getLoc();
1201-
auto idxType = rewriter.getIndexType();
1198+
auto idxType = typeConverter->convertType(rewriter.getIndexType());
12021199
auto zero = rewriter.create<LLVM::ConstantOp>(
1203-
loc, typeConverter->convertType(idxType),
1204-
rewriter.getIntegerAttr(idxType, 0));
1200+
loc, idxType, rewriter.getIntegerAttr(idxType, 0));
12051201
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
12061202
insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
12071203
return success();
@@ -1439,8 +1435,6 @@ class VectorTypeCastOpConversion
14391435
if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
14401436
return failure();
14411437

1442-
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1443-
14441438
// Create descriptor.
14451439
auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
14461440
// Set allocated ptr.
@@ -1451,21 +1445,24 @@ class VectorTypeCastOpConversion
14511445
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
14521446
desc.setAlignedPtr(rewriter, loc, ptr);
14531447
// Fill offset 0.
1454-
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1455-
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1448+
1449+
auto idxType = typeConverter->convertType(rewriter.getIndexType());
1450+
auto zero = rewriter.create<LLVM::ConstantOp>(
1451+
loc, idxType, rewriter.getIntegerAttr(idxType, 0));
14561452
desc.setOffset(rewriter, loc, zero);
14571453

14581454
// Fill size and stride descriptors in memref.
14591455
for (const auto &indexedSize :
14601456
llvm::enumerate(targetMemRefType.getShape())) {
14611457
int64_t index = indexedSize.index();
1462-
auto sizeAttr =
1463-
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1464-
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1458+
1459+
auto size = rewriter.create<LLVM::ConstantOp>(
1460+
loc, idxType, rewriter.getIntegerAttr(idxType, indexedSize.value()));
14651461
desc.setSize(rewriter, loc, index, size);
1466-
auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1467-
(*targetStrides)[index]);
1468-
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1462+
1463+
auto stride = rewriter.create<LLVM::ConstantOp>(
1464+
loc, idxType,
1465+
rewriter.getIntegerAttr(idxType, (*targetStrides)[index]));
14691466
desc.setStride(rewriter, loc, index, stride);
14701467
}
14711468

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8686

8787
// Convert to the LLVM IR dialect.
8888
LowerToLLVMOptions options(&getContext());
89+
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
90+
options.overrideIndexBitwidth(indexBitwidth);
8991
LLVMTypeConverter converter(&getContext(), options);
9092
RewritePatternSet patterns(&getContext());
9193
populateVectorTransferLoweringPatterns(patterns);

0 commit comments

Comments
 (0)