Skip to content

Commit 132b0a8

Browse files
authored
fix(mlir): Align Math intrinsics lowering with upstream review (#20)
1 parent 23ecfe5 commit 132b0a8

File tree

2 files changed

+138
-71
lines changed

2 files changed

+138
-71
lines changed

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 68 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
6969
using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
7070

7171
// A math `ctlz/cttz/absi(a)` is converted into the LLVM `ctlz/cttz/abs(a, false)` intrinsic, i.e. with the poison flag (`is_zero_poison` / `is_int_min_poison`) cleared.
72+
// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
73+
// may be better to separate the patterns.
7274
template <typename MathOp, typename LLVMOp>
7375
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
7476
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -77,31 +79,25 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
7779
LogicalResult
7880
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
7981
ConversionPatternRewriter &rewriter) const override {
82+
const auto &typeConverter = *this->getTypeConverter();
8083
auto operandType = adaptor.getOperand().getType();
81-
82-
if (!operandType || !LLVM::isCompatibleType(operandType))
84+
auto llvmOperandType = typeConverter.convertType(operandType);
85+
if (!llvmOperandType)
8386
return failure();
8487

8588
auto loc = op.getLoc();
8689
auto resultType = op.getResult().getType();
87-
const auto &typeConverter = *this->getTypeConverter();
88-
if (!LLVM::isCompatibleType(resultType)) {
89-
resultType = typeConverter.convertType(resultType);
90-
if (!resultType)
91-
return failure();
92-
}
93-
if (operandType != resultType)
94-
return rewriter.notifyMatchFailure(
95-
op, "compatible result type doesn't match operand type");
90+
auto llvmResultType = typeConverter.convertType(resultType);
91+
if (!llvmResultType)
92+
return failure();
9693

97-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
98-
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
99-
false);
94+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
95+
rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
96+
adaptor.getOperand(), false);
10097
return success();
10198
}
10299

103-
auto vectorType = dyn_cast<VectorType>(resultType);
104-
if (!vectorType)
100+
if (!isa<VectorType>(llvmResultType))
105101
return failure();
106102

107103
return LLVM::detail::handleMultidimensionalVectors(
@@ -128,40 +124,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
128124
LogicalResult
129125
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
130126
ConversionPatternRewriter &rewriter) const override {
127+
const auto &typeConverter = *this->getTypeConverter();
131128
auto operandType = adaptor.getOperand().getType();
132-
133-
if (!operandType || !LLVM::isCompatibleType(operandType))
129+
auto llvmOperandType = typeConverter.convertType(operandType);
130+
if (!llvmOperandType)
134131
return failure();
135132

136133
auto loc = op.getLoc();
137134
auto resultType = op.getResult().getType();
138-
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
135+
auto floatType = cast<FloatType>(
136+
typeConverter.convertType(getElementTypeOrSelf(resultType)));
139137
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
140138
ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
141139
ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
142140

143-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
141+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
144142
LLVM::ConstantOp one;
145-
if (LLVM::isCompatibleVectorType(operandType)) {
143+
if (LLVM::isCompatibleVectorType(llvmOperandType)) {
146144
one = rewriter.create<LLVM::ConstantOp>(
147-
loc, operandType,
148-
SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
145+
loc, llvmOperandType,
146+
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
147+
floatOne));
149148
} else {
150-
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
149+
one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
151150
}
152151
auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
153152
expAttrs.getAttrs());
154153
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
155-
op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
154+
op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
156155
return success();
157156
}
158157

159-
auto vectorType = dyn_cast<VectorType>(resultType);
160-
if (!vectorType)
158+
if (!isa<VectorType>(resultType))
161159
return rewriter.notifyMatchFailure(op, "expected vector result type");
162160

163161
return LLVM::detail::handleMultidimensionalVectors(
164-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
162+
op.getOperation(), adaptor.getOperands(), typeConverter,
165163
[&](Type llvm1DVectorTy, ValueRange operands) {
166164
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
167165
auto splatAttr = SplatElementsAttr::get(
@@ -186,41 +184,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
186184
LogicalResult
187185
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
188186
ConversionPatternRewriter &rewriter) const override {
187+
const auto &typeConverter = *this->getTypeConverter();
189188
auto operandType = adaptor.getOperand().getType();
190-
191-
if (!operandType || !LLVM::isCompatibleType(operandType))
189+
auto llvmOperandType = typeConverter.convertType(operandType);
190+
if (!llvmOperandType)
192191
return rewriter.notifyMatchFailure(op, "unsupported operand type");
193192

194193
auto loc = op.getLoc();
195194
auto resultType = op.getResult().getType();
196-
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
195+
auto floatType = cast<FloatType>(
196+
typeConverter.convertType(getElementTypeOrSelf(resultType)));
197197
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
198198
ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
199199
ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
200200

201-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
201+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
202202
LLVM::ConstantOp one =
203-
LLVM::isCompatibleVectorType(operandType)
203+
isa<VectorType>(llvmOperandType)
204204
? rewriter.create<LLVM::ConstantOp>(
205-
loc, operandType,
206-
SplatElementsAttr::get(cast<ShapedType>(resultType),
205+
loc, llvmOperandType,
206+
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
207207
floatOne))
208-
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
208+
: rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType,
209+
floatOne);
209210

210211
auto add = rewriter.create<LLVM::FAddOp>(
211-
loc, operandType, ValueRange{one, adaptor.getOperand()},
212+
loc, llvmOperandType, ValueRange{one, adaptor.getOperand()},
212213
addAttrs.getAttrs());
213-
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
214-
logAttrs.getAttrs());
214+
rewriter.replaceOpWithNewOp<LLVM::LogOp>(
215+
op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
215216
return success();
216217
}
217218

218-
auto vectorType = dyn_cast<VectorType>(resultType);
219-
if (!vectorType)
219+
if (!isa<VectorType>(resultType))
220220
return rewriter.notifyMatchFailure(op, "expected vector result type");
221221

222222
return LLVM::detail::handleMultidimensionalVectors(
223-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
223+
op.getOperation(), adaptor.getOperands(), typeConverter,
224224
[&](Type llvm1DVectorTy, ValueRange operands) {
225225
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
226226
auto splatAttr = SplatElementsAttr::get(
@@ -246,40 +246,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
246246
LogicalResult
247247
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
248248
ConversionPatternRewriter &rewriter) const override {
249+
const auto &typeConverter = *this->getTypeConverter();
249250
auto operandType = adaptor.getOperand().getType();
250-
251-
if (!operandType || !LLVM::isCompatibleType(operandType))
251+
auto llvmOperandType = typeConverter.convertType(operandType);
252+
if (!llvmOperandType)
252253
return failure();
253254

254255
auto loc = op.getLoc();
255256
auto resultType = op.getResult().getType();
256-
auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
257+
auto floatType = cast<FloatType>(
258+
typeConverter.convertType(getElementTypeOrSelf(resultType)));
257259
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
258260
ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
259261
ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
260262

261-
if (!isa<LLVM::LLVMArrayType>(operandType)) {
263+
if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
262264
LLVM::ConstantOp one;
263-
if (LLVM::isCompatibleVectorType(operandType)) {
265+
if (isa<VectorType>(llvmOperandType)) {
264266
one = rewriter.create<LLVM::ConstantOp>(
265-
loc, operandType,
266-
SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
267+
loc, llvmOperandType,
268+
SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
269+
floatOne));
267270
} else {
268-
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
271+
one = rewriter.create<LLVM::ConstantOp>(loc, llvmOperandType, floatOne);
269272
}
270273
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
271274
sqrtAttrs.getAttrs());
272275
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
273-
op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
276+
op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
274277
return success();
275278
}
276279

277-
auto vectorType = dyn_cast<VectorType>(resultType);
278-
if (!vectorType)
280+
if (!isa<VectorType>(resultType))
279281
return failure();
280282

281283
return LLVM::detail::handleMultidimensionalVectors(
282-
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
284+
op.getOperation(), adaptor.getOperands(), typeConverter,
283285
[&](Type llvm1DVectorTy, ValueRange operands) {
284286
auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
285287
auto splatAttr = SplatElementsAttr::get(
@@ -303,13 +305,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
303305
LogicalResult
304306
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
305307
ConversionPatternRewriter &rewriter) const override {
306-
auto operandType = adaptor.getOperand().getType();
307-
308-
if (!operandType || !LLVM::isCompatibleType(operandType))
308+
const auto &typeConverter = *this->getTypeConverter();
309+
auto operandType =
310+
typeConverter.convertType(adaptor.getOperand().getType());
311+
auto resultType = typeConverter.convertType(op.getResult().getType());
312+
if (!operandType || !resultType)
309313
return failure();
310314

311315
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
312-
op, op.getType(), adaptor.getOperand(), llvm::fcNan);
316+
op, resultType, adaptor.getOperand(), llvm::fcNan);
313317
return success();
314318
}
315319
};
@@ -320,13 +324,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
320324
LogicalResult
321325
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
322326
ConversionPatternRewriter &rewriter) const override {
323-
auto operandType = adaptor.getOperand().getType();
324-
325-
if (!operandType || !LLVM::isCompatibleType(operandType))
327+
const auto &typeConverter = *this->getTypeConverter();
328+
auto operandType =
329+
typeConverter.convertType(adaptor.getOperand().getType());
330+
auto resultType = typeConverter.convertType(op.getResult().getType());
331+
if (!operandType || !resultType)
326332
return failure();
327333

328334
rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
329-
op, op.getType(), adaptor.getOperand(), llvm::fcFinite);
335+
op, resultType, adaptor.getOperand(), llvm::fcFinite);
330336
return success();
331337
}
332338
};

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ func.func @absi(%arg0: i32) -> i32 {
2929

3030
// -----
3131

32-
// CHECK-LABEL: func @absi_0d_vec(
33-
// CHECK-SAME: i32
34-
func.func @absi_0d_vec(%arg0 : vector<i32>) {
32+
// CHECK-LABEL: func @absi_0dvector(
33+
// CHECK-SAME: vector<i32>
34+
func.func @absi_0dvector(%arg0 : vector<i32>) {
3535
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
3636
// CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
3737
%0 = math.absi %arg0 : vector<i32>
@@ -102,6 +102,19 @@ func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
102102

103103
// -----
104104

105+
// CHECK-LABEL: func @log1p_0dvector(
106+
// CHECK-SAME: vector<f32>
107+
func.func @log1p_0dvector(%arg0 : vector<f32>) {
108+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
109+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
110+
// CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[CAST]] : vector<1xf32>
111+
// CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) : (vector<1xf32>) -> vector<1xf32>
112+
%0 = math.log1p %arg0 : vector<f32>
113+
func.return
114+
}
115+
116+
// -----
117+
105118
// CHECK-LABEL: func @expm1(
106119
// CHECK-SAME: f32
107120
func.func @expm1(%arg0 : f32) {
@@ -162,6 +175,19 @@ func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
162175

163176
// -----
164177

178+
// CHECK-LABEL: func @expm1_0dvector(
179+
// CHECK-SAME: vector<f32>
180+
func.func @expm1_0dvector(%arg0 : vector<f32>) {
181+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
182+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
183+
// CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[CAST]]) : (vector<1xf32>) -> vector<1xf32>
184+
// CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<1xf32>
185+
%0 = math.expm1 %arg0 : vector<f32>
186+
func.return
187+
}
188+
189+
// -----
190+
165191
// CHECK-LABEL: func @rsqrt(
166192
// CHECK-SAME: f32
167193
func.func @rsqrt(%arg0 : f32) {
@@ -174,6 +200,19 @@ func.func @rsqrt(%arg0 : f32) {
174200

175201
// -----
176202

203+
// CHECK-LABEL: func @rsqrt_0dvector(
204+
// CHECK-SAME: vector<f32>
205+
func.func @rsqrt_0dvector(%arg0 : vector<f32>) {
206+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
207+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<1xf32>) : vector<1xf32>
208+
// CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[CAST]]) : (vector<1xf32>) -> vector<1xf32>
209+
// CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<1xf32>
210+
%0 = math.rsqrt %arg0 : vector<f32>
211+
func.return
212+
}
213+
214+
// -----
215+
177216
// CHECK-LABEL: func @trigonometrics
178217
// CHECK-SAME: [[ARG0:%.+]]: f32
179218
func.func @trigonometrics(%arg0: f32) {
@@ -214,9 +253,9 @@ func.func @ctlz(%arg0 : i32) {
214253
func.return
215254
}
216255

217-
// CHECK-LABEL: func @ctlz_0d_vec(
218-
// CHECK-SAME: i32
219-
func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
256+
// CHECK-LABEL: func @ctlz_0dvector(
257+
// CHECK-SAME: vector<i32>
258+
func.func @ctlz_0dvector(%arg0 : vector<i32>) {
220259
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
221260
// CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
222261
%0 = math.ctlz %arg0 : vector<i32>
@@ -235,9 +274,9 @@ func.func @cttz(%arg0 : i32) {
235274

236275
// -----
237276

238-
// CHECK-LABEL: func @cttz_0d_vec(
239-
// CHECK-SAME: i32
240-
func.func @cttz_0d_vec(%arg0 : vector<i32>) {
277+
// CHECK-LABEL: func @cttz_0dvector(
278+
// CHECK-SAME: vector<i32>
279+
func.func @cttz_0dvector(%arg0 : vector<i32>) {
241280
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
242281
// CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
243282
%0 = math.cttz %arg0 : vector<i32>
@@ -306,6 +345,17 @@ func.func @isnan_double(%arg0 : f64) {
306345

307346
// -----
308347

348+
// CHECK-LABEL: func @isnan_0dvector(
349+
// CHECK-SAME: vector<f32>
350+
func.func @isnan_0dvector(%arg0 : vector<f32>) {
351+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
352+
// CHECK: "llvm.intr.is.fpclass"(%[[CAST]]) <{bit = 3 : i32}> : (vector<1xf32>) -> vector<1xi1>
353+
%0 = math.isnan %arg0 : vector<f32>
354+
func.return
355+
}
356+
357+
// -----
358+
309359
// CHECK-LABEL: func @isfinite_double(
310360
// CHECK-SAME: f64
311361
func.func @isfinite_double(%arg0 : f64) {
@@ -316,6 +366,17 @@ func.func @isfinite_double(%arg0 : f64) {
316366

317367
// -----
318368

369+
// CHECK-LABEL: func @isfinite_0dvector(
370+
// CHECK-SAME: vector<f32>
371+
func.func @isfinite_0dvector(%arg0 : vector<f32>) {
372+
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<f32> to vector<1xf32>
373+
// CHECK: "llvm.intr.is.fpclass"(%[[CAST]]) <{bit = 504 : i32}> : (vector<1xf32>) -> vector<1xi1>
374+
%0 = math.isfinite %arg0 : vector<f32>
375+
func.return
376+
}
377+
378+
// -----
379+
319380
// CHECK-LABEL: func @rsqrt_double(
320381
// CHECK-SAME: f64
321382
func.func @rsqrt_double(%arg0 : f64) {

0 commit comments

Comments
 (0)