Skip to content

Commit 96007f1

Browse files
authored
fix(mlir): fix Math-to-LLVM intrinsic conversions for 0D-vector types (#13)
`vector<t>` types are not compatible with the LLVM type system, and must be explicitly converted into `vector<1xt>` when lowering. Employ this rule within the conversion pattern for `math.ctlz`, `.cttz` and `.absi` intrinsics.
1 parent eef4f66 commit 96007f1

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
8484

8585
auto loc = op.getLoc();
8686
auto resultType = op.getResult().getType();
87+
const auto &typeConverter = *this->getTypeConverter();
88+
if (!LLVM::isCompatibleType(resultType)) {
89+
resultType = typeConverter.convertType(resultType);
90+
if (!resultType)
91+
return failure();
92+
}
93+
if (operandType != resultType)
94+
return rewriter.notifyMatchFailure(
95+
op, "compatible result type doesn't match operand type");
8796

8897
if (!isa<LLVM::LLVMArrayType>(operandType)) {
8998
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
@@ -96,7 +105,7 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
96105
return failure();
97106

98107
return LLVM::detail::handleMultidimensionalVectors(
99-
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
108+
op.getOperation(), adaptor.getOperands(), typeConverter,
100109
[&](Type llvm1DVectorTy, ValueRange operands) {
101110
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
102111
false);

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
1919

2020
// -----
2121

22+
// CHECK-LABEL: func @absi(
23+
// CHECK-SAME: i32
2224
func.func @absi(%arg0: i32) -> i32 {
2325
// CHECK: = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32
2426
%0 = math.absi %arg0 : i32
@@ -27,6 +29,17 @@ func.func @absi(%arg0: i32) -> i32 {
2729

2830
// -----
2931

32+
// CHECK-LABEL: func @absi_0d_vec(
33+
// CHECK-SAME: i32
34+
func.func @absi_0d_vec(%arg0 : vector<i32>) {
35+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
36+
// CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
37+
%0 = math.absi %arg0 : vector<i32>
38+
func.return
39+
}
40+
41+
// -----
42+
3043
// CHECK-LABEL: func @log1p(
3144
// CHECK-SAME: f32
3245
func.func @log1p(%arg0 : f32) {
@@ -201,6 +214,15 @@ func.func @ctlz(%arg0 : i32) {
201214
func.return
202215
}
203216

217+
// CHECK-LABEL: func @ctlz_0d_vec(
218+
// CHECK-SAME: i32
219+
func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
220+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
221+
// CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
222+
%0 = math.ctlz %arg0 : vector<i32>
223+
func.return
224+
}
225+
204226
// -----
205227

206228
// CHECK-LABEL: func @cttz(
@@ -213,6 +235,17 @@ func.func @cttz(%arg0 : i32) {
213235

214236
// -----
215237

238+
// CHECK-LABEL: func @cttz_0d_vec(
239+
// CHECK-SAME: i32
240+
func.func @cttz_0d_vec(%arg0 : vector<i32>) {
241+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
242+
// CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
243+
%0 = math.cttz %arg0 : vector<i32>
244+
func.return
245+
}
246+
247+
// -----
248+
216249
// CHECK-LABEL: func @cttz_vec(
217250
// CHECK-SAME: i32
218251
func.func @cttz_vec(%arg0 : vector<4xi32>) {

0 commit comments

Comments
 (0)