Skip to content

Commit 45d715a

Browse files
committed
Add index-bitwidth option to the convert-vector-to-llvm pass
Change-Id: I1ad6f77183f1f1faf25e935131de4ef3a4334150
1 parent af64f0a commit 45d715a

File tree

6 files changed

+427
-17
lines changed

6 files changed

+427
-17
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14141414
"vector::VectorTransformsOptions",
14151415
/*default=*/"vector::VectorTransformsOptions()",
14161416
"Options to lower some operations like contractions and transposes.">,
1417+
Option<"indexBitwidth", "index-bitwidth", "unsigned",
1418+
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
1419+
"Bitwidth of the index type, 0 to use size of machine word">,
14171420
];
14181421
}
14191422

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,8 +1439,6 @@ class VectorTypeCastOpConversion
14391439
if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
14401440
return failure();
14411441

1442-
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1443-
14441442
// Create descriptor.
14451443
auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
14461444
// Set allocated ptr.
@@ -1451,21 +1449,26 @@ class VectorTypeCastOpConversion
14511449
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
14521450
desc.setAlignedPtr(rewriter, loc, ptr);
14531451
// Fill offset 0.
1454-
auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1455-
auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1452+
1453+
auto idxType = rewriter.getIndexType();
1454+
auto zero = rewriter.create<LLVM::ConstantOp>(
1455+
loc, typeConverter->convertType(idxType),
1456+
rewriter.getIntegerAttr(idxType, 0));
14561457
desc.setOffset(rewriter, loc, zero);
14571458

14581459
// Fill size and stride descriptors in memref.
14591460
for (const auto &indexedSize :
14601461
llvm::enumerate(targetMemRefType.getShape())) {
14611462
int64_t index = indexedSize.index();
1462-
auto sizeAttr =
1463-
rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1464-
auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1463+
1464+
auto size = rewriter.create<LLVM::ConstantOp>(
1465+
loc, typeConverter->convertType(idxType),
1466+
rewriter.getIntegerAttr(idxType, indexedSize.value()));
14651467
desc.setSize(rewriter, loc, index, size);
1466-
auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1467-
(*targetStrides)[index]);
1468-
auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1468+
1469+
auto stride = rewriter.create<LLVM::ConstantOp>(
1470+
loc, typeConverter->convertType(idxType),
1471+
rewriter.getIntegerAttr(idxType, (*targetStrides)[index]));
14691472
desc.setStride(rewriter, loc, index, stride);
14701473
}
14711474

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
10-
10+
#include "mlir/Analysis/DataLayoutAnalysis.h"
1111
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1212
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1313
#include "mlir/Dialect/AMX/AMXDialect.h"
@@ -64,6 +64,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
6464
// Perform progressive lowering of operations on slices and all contraction
6565
// operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
6666
// applies folding and DCE.
67+
Operation *op = getOperation();
68+
const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
6769
{
6870
RewritePatternSet patterns(&getContext());
6971
populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -83,10 +85,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8385
populateVectorRankReducingFMAPattern(patterns);
8486
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
8587
}
86-
8788
// Convert to the LLVM IR dialect.
88-
LowerToLLVMOptions options(&getContext());
89-
LLVMTypeConverter converter(&getContext(), options);
89+
LowerToLLVMOptions options(&getContext(),
90+
dataLayoutAnalysis.getAtOrAbove(op));
91+
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
92+
options.overrideIndexBitwidth(indexBitwidth);
93+
LLVMTypeConverter converter(&getContext(), options, &dataLayoutAnalysis);
9094
RewritePatternSet patterns(&getContext());
9195
populateVectorTransferLoweringPatterns(patterns);
9296
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s
2+
3+
module attributes {dlti.dl_spec = #dlti.dl_spec< #dlti.dl_entry<index, 32>>} {
4+
// CHECK-LABEL: func.func @broadcast_vec2d_from_vec0d(
5+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<f32>) -> vector<3x2xf32> {
6+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<f32> to vector<1xf32>
7+
// CHECK: %[[VAL_2:.*]] = ub.poison : vector<3x2xf32>
8+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
9+
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : index) : i32
10+
// CHECK: %[[VAL_5:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_4]] : i32] : vector<1xf32>
11+
// CHECK: %[[VAL_6:.*]] = llvm.mlir.poison : vector<2xf32>
12+
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : i32
13+
// CHECK: %[[VAL_8:.*]] = llvm.insertelement %[[VAL_5]], %[[VAL_6]]{{\[}}%[[VAL_7]] : i32] : vector<2xf32>
14+
// CHECK: %[[VAL_9:.*]] = llvm.shufflevector %[[VAL_8]], %[[VAL_6]] [0, 0] : vector<2xf32>
15+
// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_3]][0] : !llvm.array<3 x vector<2xf32>>
16+
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_10]][1] : !llvm.array<3 x vector<2xf32>>
17+
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_11]][2] : !llvm.array<3 x vector<2xf32>>
18+
// CHECK: %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
19+
// CHECK: return %[[VAL_13]] : vector<3x2xf32>
20+
// CHECK: }
21+
func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
22+
%0 = vector.broadcast %arg0 : vector<f32> to vector<3x2xf32>
23+
return %0 : vector<3x2xf32>
24+
}
25+
}

0 commit comments

Comments
 (0)