@@ -73,6 +73,8 @@ using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
7373using ATan2OpLowering =
7474 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
7575// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
76+ // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
77+ // may be better to separate the patterns.
7678template <typename MathOp, typename LLVMOp>
7779struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern <MathOp> {
7880 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -81,26 +83,29 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
8183 LogicalResult
8284 matchAndRewrite (MathOp op, typename MathOp::Adaptor adaptor,
8385 ConversionPatternRewriter &rewriter) const override {
86+ const auto &typeConverter = *this ->getTypeConverter ();
8487 auto operandType = adaptor.getOperand ().getType ();
85-
86- if (!operandType || ! LLVM::isCompatibleType (operandType) )
88+ auto llvmOperandType = typeConverter. convertType (operandType);
89+ if (!llvmOperandType )
8790 return failure ();
8891
8992 auto loc = op.getLoc ();
9093 auto resultType = op.getResult ().getType ();
94+ auto llvmResultType = typeConverter.convertType (resultType);
95+ if (!llvmResultType)
96+ return failure ();
9197
92- if (!isa<LLVM::LLVMArrayType>(operandType )) {
93- rewriter.replaceOpWithNewOp <LLVMOp>(op, resultType, adaptor. getOperand () ,
94- false );
98+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
99+ rewriter.replaceOpWithNewOp <LLVMOp>(op, llvmResultType ,
100+ adaptor. getOperand (), false );
95101 return success ();
96102 }
97103
98- auto vectorType = dyn_cast<VectorType>(resultType);
99- if (!vectorType)
104+ if (!isa<VectorType>(llvmResultType))
100105 return failure ();
101106
102107 return LLVM::detail::handleMultidimensionalVectors (
103- op.getOperation (), adaptor.getOperands (), * this -> getTypeConverter () ,
108+ op.getOperation (), adaptor.getOperands (), typeConverter ,
104109 [&](Type llvm1DVectorTy, ValueRange operands) {
105110 return rewriter.create <LLVMOp>(loc, llvm1DVectorTy, operands[0 ],
106111 false );
@@ -123,40 +128,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
123128 LogicalResult
124129 matchAndRewrite (math::ExpM1Op op, OpAdaptor adaptor,
125130 ConversionPatternRewriter &rewriter) const override {
131+ const auto &typeConverter = *this ->getTypeConverter ();
126132 auto operandType = adaptor.getOperand ().getType ();
127-
128- if (!operandType || ! LLVM::isCompatibleType (operandType) )
133+ auto llvmOperandType = typeConverter. convertType (operandType);
134+ if (!llvmOperandType )
129135 return failure ();
130136
131137 auto loc = op.getLoc ();
132138 auto resultType = op.getResult ().getType ();
133- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
139+ auto floatType = cast<FloatType>(
140+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
134141 auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
135142 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs (op);
136143 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs (op);
137144
138- if (!isa<LLVM::LLVMArrayType>(operandType )) {
145+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
139146 LLVM::ConstantOp one;
140- if (LLVM::isCompatibleVectorType (operandType )) {
147+ if (LLVM::isCompatibleVectorType (llvmOperandType )) {
141148 one = rewriter.create <LLVM::ConstantOp>(
142- loc, operandType,
143- SplatElementsAttr::get (cast<ShapedType>(resultType), floatOne));
149+ loc, llvmOperandType,
150+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType),
151+ floatOne));
144152 } else {
145- one = rewriter.create <LLVM::ConstantOp>(loc, operandType , floatOne);
153+ one = rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType , floatOne);
146154 }
147155 auto exp = rewriter.create <LLVM::ExpOp>(loc, adaptor.getOperand (),
148156 expAttrs.getAttrs ());
149157 rewriter.replaceOpWithNewOp <LLVM::FSubOp>(
150- op, operandType , ValueRange{exp, one}, subAttrs.getAttrs ());
158+ op, llvmOperandType , ValueRange{exp, one}, subAttrs.getAttrs ());
151159 return success ();
152160 }
153161
154- auto vectorType = dyn_cast<VectorType>(resultType);
155- if (!vectorType)
162+ if (!isa<VectorType>(resultType))
156163 return rewriter.notifyMatchFailure (op, " expected vector result type" );
157164
158165 return LLVM::detail::handleMultidimensionalVectors (
159- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
166+ op.getOperation (), adaptor.getOperands (), typeConverter ,
160167 [&](Type llvm1DVectorTy, ValueRange operands) {
161168 auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
162169 auto splatAttr = SplatElementsAttr::get (
@@ -181,41 +188,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
181188 LogicalResult
182189 matchAndRewrite (math::Log1pOp op, OpAdaptor adaptor,
183190 ConversionPatternRewriter &rewriter) const override {
191+ const auto &typeConverter = *this ->getTypeConverter ();
184192 auto operandType = adaptor.getOperand ().getType ();
185-
186- if (!operandType || ! LLVM::isCompatibleType (operandType) )
193+ auto llvmOperandType = typeConverter. convertType (operandType);
194+ if (!llvmOperandType )
187195 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
188196
189197 auto loc = op.getLoc ();
190198 auto resultType = op.getResult ().getType ();
191- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
199+ auto floatType = cast<FloatType>(
200+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
192201 auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
193202 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs (op);
194203 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs (op);
195204
196- if (!isa<LLVM::LLVMArrayType>(operandType )) {
205+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
197206 LLVM::ConstantOp one =
198- LLVM::isCompatibleVectorType (operandType )
207+ isa<VectorType>(llvmOperandType )
199208 ? rewriter.create <LLVM::ConstantOp>(
200- loc, operandType ,
201- SplatElementsAttr::get (cast<ShapedType>(resultType ),
209+ loc, llvmOperandType ,
210+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType ),
202211 floatOne))
203- : rewriter.create <LLVM::ConstantOp>(loc, operandType, floatOne);
212+ : rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType,
213+ floatOne);
204214
205215 auto add = rewriter.create <LLVM::FAddOp>(
206- loc, operandType , ValueRange{one, adaptor.getOperand ()},
216+ loc, llvmOperandType , ValueRange{one, adaptor.getOperand ()},
207217 addAttrs.getAttrs ());
208- rewriter.replaceOpWithNewOp <LLVM::LogOp>(op, operandType, ValueRange{add},
209- logAttrs.getAttrs ());
218+ rewriter.replaceOpWithNewOp <LLVM::LogOp>(
219+ op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs ());
210220 return success ();
211221 }
212222
213- auto vectorType = dyn_cast<VectorType>(resultType);
214- if (!vectorType)
223+ if (!isa<VectorType>(resultType))
215224 return rewriter.notifyMatchFailure (op, " expected vector result type" );
216225
217226 return LLVM::detail::handleMultidimensionalVectors (
218- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
227+ op.getOperation (), adaptor.getOperands (), typeConverter ,
219228 [&](Type llvm1DVectorTy, ValueRange operands) {
220229 auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
221230 auto splatAttr = SplatElementsAttr::get (
@@ -241,40 +250,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
241250 LogicalResult
242251 matchAndRewrite (math::RsqrtOp op, OpAdaptor adaptor,
243252 ConversionPatternRewriter &rewriter) const override {
253+ const auto &typeConverter = *this ->getTypeConverter ();
244254 auto operandType = adaptor.getOperand ().getType ();
245-
246- if (!operandType || ! LLVM::isCompatibleType (operandType) )
255+ auto llvmOperandType = typeConverter. convertType (operandType);
256+ if (!llvmOperandType )
247257 return failure ();
248258
249259 auto loc = op.getLoc ();
250260 auto resultType = op.getResult ().getType ();
251- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
261+ auto floatType = cast<FloatType>(
262+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
252263 auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
253264 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs (op);
254265 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs (op);
255266
256- if (!isa<LLVM::LLVMArrayType>(operandType )) {
267+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
257268 LLVM::ConstantOp one;
258- if (LLVM::isCompatibleVectorType (operandType )) {
269+ if (isa<VectorType>(llvmOperandType )) {
259270 one = rewriter.create <LLVM::ConstantOp>(
260- loc, operandType,
261- SplatElementsAttr::get (cast<ShapedType>(resultType), floatOne));
271+ loc, llvmOperandType,
272+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType),
273+ floatOne));
262274 } else {
263- one = rewriter.create <LLVM::ConstantOp>(loc, operandType , floatOne);
275+ one = rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType , floatOne);
264276 }
265277 auto sqrt = rewriter.create <LLVM::SqrtOp>(loc, adaptor.getOperand (),
266278 sqrtAttrs.getAttrs ());
267279 rewriter.replaceOpWithNewOp <LLVM::FDivOp>(
268- op, operandType , ValueRange{one, sqrt}, divAttrs.getAttrs ());
280+ op, llvmOperandType , ValueRange{one, sqrt}, divAttrs.getAttrs ());
269281 return success ();
270282 }
271283
272- auto vectorType = dyn_cast<VectorType>(resultType);
273- if (!vectorType)
284+ if (!isa<VectorType>(resultType))
274285 return failure ();
275286
276287 return LLVM::detail::handleMultidimensionalVectors (
277- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
288+ op.getOperation (), adaptor.getOperands (), typeConverter ,
278289 [&](Type llvm1DVectorTy, ValueRange operands) {
279290 auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
280291 auto splatAttr = SplatElementsAttr::get (
@@ -298,13 +309,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
298309 LogicalResult
299310 matchAndRewrite (math::IsNaNOp op, OpAdaptor adaptor,
300311 ConversionPatternRewriter &rewriter) const override {
301- auto operandType = adaptor.getOperand ().getType ();
302-
303- if (!operandType || !LLVM::isCompatibleType (operandType))
312+ const auto &typeConverter = *this ->getTypeConverter ();
313+ auto operandType =
314+ typeConverter.convertType (adaptor.getOperand ().getType ());
315+ auto resultType = typeConverter.convertType (op.getResult ().getType ());
316+ if (!operandType || !resultType)
304317 return failure ();
305318
306319 rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
307- op, op. getType () , adaptor.getOperand (), llvm::fcNan);
320+ op, resultType , adaptor.getOperand (), llvm::fcNan);
308321 return success ();
309322 }
310323};
@@ -315,13 +328,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
315328 LogicalResult
316329 matchAndRewrite (math::IsFiniteOp op, OpAdaptor adaptor,
317330 ConversionPatternRewriter &rewriter) const override {
318- auto operandType = adaptor.getOperand ().getType ();
319-
320- if (!operandType || !LLVM::isCompatibleType (operandType))
331+ const auto &typeConverter = *this ->getTypeConverter ();
332+ auto operandType =
333+ typeConverter.convertType (adaptor.getOperand ().getType ());
334+ auto resultType = typeConverter.convertType (op.getResult ().getType ());
335+ if (!operandType || !resultType)
321336 return failure ();
322337
323338 rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
324- op, op. getType () , adaptor.getOperand (), llvm::fcFinite);
339+ op, resultType , adaptor.getOperand (), llvm::fcFinite);
325340 return success ();
326341 }
327342};
0 commit comments