@@ -69,6 +69,8 @@ using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
6969using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
7070
7171// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
72+ // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
73+ // may be better to separate the patterns.
7274template <typename MathOp, typename LLVMOp>
7375struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern <MathOp> {
7476 using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
@@ -77,31 +79,25 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
7779 LogicalResult
7880 matchAndRewrite (MathOp op, typename MathOp::Adaptor adaptor,
7981 ConversionPatternRewriter &rewriter) const override {
82+ const auto &typeConverter = *this ->getTypeConverter ();
8083 auto operandType = adaptor.getOperand ().getType ();
81-
82- if (!operandType || ! LLVM::isCompatibleType (operandType) )
84+ auto llvmOperandType = typeConverter. convertType (operandType);
85+ if (!llvmOperandType )
8386 return failure ();
8487
8588 auto loc = op.getLoc ();
8689 auto resultType = op.getResult ().getType ();
87- const auto &typeConverter = *this ->getTypeConverter ();
88- if (!LLVM::isCompatibleType (resultType)) {
89- resultType = typeConverter.convertType (resultType);
90- if (!resultType)
91- return failure ();
92- }
93- if (operandType != resultType)
94- return rewriter.notifyMatchFailure (
95- op, " compatible result type doesn't match operand type" );
90+ auto llvmResultType = typeConverter.convertType (resultType);
91+ if (!llvmResultType)
92+ return failure ();
9693
97- if (!isa<LLVM::LLVMArrayType>(operandType )) {
98- rewriter.replaceOpWithNewOp <LLVMOp>(op, resultType, adaptor. getOperand () ,
99- false );
94+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
95+ rewriter.replaceOpWithNewOp <LLVMOp>(op, llvmResultType ,
96+ adaptor. getOperand (), false );
10097 return success ();
10198 }
10299
103- auto vectorType = dyn_cast<VectorType>(resultType);
104- if (!vectorType)
100+ if (!isa<VectorType>(llvmResultType))
105101 return failure ();
106102
107103 return LLVM::detail::handleMultidimensionalVectors (
@@ -128,40 +124,42 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
128124 LogicalResult
129125 matchAndRewrite (math::ExpM1Op op, OpAdaptor adaptor,
130126 ConversionPatternRewriter &rewriter) const override {
127+ const auto &typeConverter = *this ->getTypeConverter ();
131128 auto operandType = adaptor.getOperand ().getType ();
132-
133- if (!operandType || ! LLVM::isCompatibleType (operandType) )
129+ auto llvmOperandType = typeConverter. convertType (operandType);
130+ if (!llvmOperandType )
134131 return failure ();
135132
136133 auto loc = op.getLoc ();
137134 auto resultType = op.getResult ().getType ();
138- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
135+ auto floatType = cast<FloatType>(
136+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
139137 auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
140138 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs (op);
141139 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs (op);
142140
143- if (!isa<LLVM::LLVMArrayType>(operandType )) {
141+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
144142 LLVM::ConstantOp one;
145- if (LLVM::isCompatibleVectorType (operandType )) {
143+ if (LLVM::isCompatibleVectorType (llvmOperandType )) {
146144 one = rewriter.create <LLVM::ConstantOp>(
147- loc, operandType,
148- SplatElementsAttr::get (cast<ShapedType>(resultType), floatOne));
145+ loc, llvmOperandType,
146+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType),
147+ floatOne));
149148 } else {
150- one = rewriter.create <LLVM::ConstantOp>(loc, operandType , floatOne);
149+ one = rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType , floatOne);
151150 }
152151 auto exp = rewriter.create <LLVM::ExpOp>(loc, adaptor.getOperand (),
153152 expAttrs.getAttrs ());
154153 rewriter.replaceOpWithNewOp <LLVM::FSubOp>(
155- op, operandType , ValueRange{exp, one}, subAttrs.getAttrs ());
154+ op, llvmOperandType , ValueRange{exp, one}, subAttrs.getAttrs ());
156155 return success ();
157156 }
158157
159- auto vectorType = dyn_cast<VectorType>(resultType);
160- if (!vectorType)
158+ if (!isa<VectorType>(resultType))
161159 return rewriter.notifyMatchFailure (op, " expected vector result type" );
162160
163161 return LLVM::detail::handleMultidimensionalVectors (
164- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
162+ op.getOperation (), adaptor.getOperands (), typeConverter ,
165163 [&](Type llvm1DVectorTy, ValueRange operands) {
166164 auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
167165 auto splatAttr = SplatElementsAttr::get (
@@ -186,41 +184,43 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
186184 LogicalResult
187185 matchAndRewrite (math::Log1pOp op, OpAdaptor adaptor,
188186 ConversionPatternRewriter &rewriter) const override {
187+ const auto &typeConverter = *this ->getTypeConverter ();
189188 auto operandType = adaptor.getOperand ().getType ();
190-
191- if (!operandType || ! LLVM::isCompatibleType (operandType) )
189+ auto llvmOperandType = typeConverter. convertType (operandType);
190+ if (!llvmOperandType )
192191 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
193192
194193 auto loc = op.getLoc ();
195194 auto resultType = op.getResult ().getType ();
196- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
195+ auto floatType = cast<FloatType>(
196+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
197197 auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
198198 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs (op);
199199 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs (op);
200200
201- if (!isa<LLVM::LLVMArrayType>(operandType )) {
201+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
202202 LLVM::ConstantOp one =
203- LLVM::isCompatibleVectorType (operandType )
203+ isa<VectorType>(llvmOperandType )
204204 ? rewriter.create <LLVM::ConstantOp>(
205- loc, operandType ,
206- SplatElementsAttr::get (cast<ShapedType>(resultType ),
205+ loc, llvmOperandType ,
206+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType ),
207207 floatOne))
208- : rewriter.create <LLVM::ConstantOp>(loc, operandType, floatOne);
208+ : rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType,
209+ floatOne);
209210
210211 auto add = rewriter.create <LLVM::FAddOp>(
211- loc, operandType , ValueRange{one, adaptor.getOperand ()},
212+ loc, llvmOperandType , ValueRange{one, adaptor.getOperand ()},
212213 addAttrs.getAttrs ());
213- rewriter.replaceOpWithNewOp <LLVM::LogOp>(op, operandType, ValueRange{add},
214- logAttrs.getAttrs ());
214+ rewriter.replaceOpWithNewOp <LLVM::LogOp>(
215+ op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs ());
215216 return success ();
216217 }
217218
218- auto vectorType = dyn_cast<VectorType>(resultType);
219- if (!vectorType)
219+ if (!isa<VectorType>(resultType))
220220 return rewriter.notifyMatchFailure (op, " expected vector result type" );
221221
222222 return LLVM::detail::handleMultidimensionalVectors (
223- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
223+ op.getOperation (), adaptor.getOperands (), typeConverter ,
224224 [&](Type llvm1DVectorTy, ValueRange operands) {
225225 auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
226226 auto splatAttr = SplatElementsAttr::get (
@@ -246,40 +246,42 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
246246 LogicalResult
247247 matchAndRewrite (math::RsqrtOp op, OpAdaptor adaptor,
248248 ConversionPatternRewriter &rewriter) const override {
249+ const auto &typeConverter = *this ->getTypeConverter ();
249250 auto operandType = adaptor.getOperand ().getType ();
250-
251- if (!operandType || ! LLVM::isCompatibleType (operandType) )
251+ auto llvmOperandType = typeConverter. convertType (operandType);
252+ if (!llvmOperandType )
252253 return failure ();
253254
254255 auto loc = op.getLoc ();
255256 auto resultType = op.getResult ().getType ();
256- auto floatType = cast<FloatType>(getElementTypeOrSelf (resultType));
257+ auto floatType = cast<FloatType>(
258+ typeConverter.convertType (getElementTypeOrSelf (resultType)));
257259 auto floatOne = rewriter.getFloatAttr (floatType, 1.0 );
258260 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs (op);
259261 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs (op);
260262
261- if (!isa<LLVM::LLVMArrayType>(operandType )) {
263+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType )) {
262264 LLVM::ConstantOp one;
263- if (LLVM::isCompatibleVectorType (operandType )) {
265+ if (isa<VectorType>(llvmOperandType )) {
264266 one = rewriter.create <LLVM::ConstantOp>(
265- loc, operandType,
266- SplatElementsAttr::get (cast<ShapedType>(resultType), floatOne));
267+ loc, llvmOperandType,
268+ SplatElementsAttr::get (cast<ShapedType>(llvmOperandType),
269+ floatOne));
267270 } else {
268- one = rewriter.create <LLVM::ConstantOp>(loc, operandType , floatOne);
271+ one = rewriter.create <LLVM::ConstantOp>(loc, llvmOperandType , floatOne);
269272 }
270273 auto sqrt = rewriter.create <LLVM::SqrtOp>(loc, adaptor.getOperand (),
271274 sqrtAttrs.getAttrs ());
272275 rewriter.replaceOpWithNewOp <LLVM::FDivOp>(
273- op, operandType , ValueRange{one, sqrt}, divAttrs.getAttrs ());
276+ op, llvmOperandType , ValueRange{one, sqrt}, divAttrs.getAttrs ());
274277 return success ();
275278 }
276279
277- auto vectorType = dyn_cast<VectorType>(resultType);
278- if (!vectorType)
280+ if (!isa<VectorType>(resultType))
279281 return failure ();
280282
281283 return LLVM::detail::handleMultidimensionalVectors (
282- op.getOperation (), adaptor.getOperands (), * getTypeConverter () ,
284+ op.getOperation (), adaptor.getOperands (), typeConverter ,
283285 [&](Type llvm1DVectorTy, ValueRange operands) {
284286 auto numElements = LLVM::getVectorNumElements (llvm1DVectorTy);
285287 auto splatAttr = SplatElementsAttr::get (
@@ -303,13 +305,15 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
303305 LogicalResult
304306 matchAndRewrite (math::IsNaNOp op, OpAdaptor adaptor,
305307 ConversionPatternRewriter &rewriter) const override {
306- auto operandType = adaptor.getOperand ().getType ();
307-
308- if (!operandType || !LLVM::isCompatibleType (operandType))
308+ const auto &typeConverter = *this ->getTypeConverter ();
309+ auto operandType =
310+ typeConverter.convertType (adaptor.getOperand ().getType ());
311+ auto resultType = typeConverter.convertType (op.getResult ().getType ());
312+ if (!operandType || !resultType)
309313 return failure ();
310314
311315 rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
312- op, op. getType () , adaptor.getOperand (), llvm::fcNan);
316+ op, resultType , adaptor.getOperand (), llvm::fcNan);
313317 return success ();
314318 }
315319};
@@ -320,13 +324,15 @@ struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
320324 LogicalResult
321325 matchAndRewrite (math::IsFiniteOp op, OpAdaptor adaptor,
322326 ConversionPatternRewriter &rewriter) const override {
323- auto operandType = adaptor.getOperand ().getType ();
324-
325- if (!operandType || !LLVM::isCompatibleType (operandType))
327+ const auto &typeConverter = *this ->getTypeConverter ();
328+ auto operandType =
329+ typeConverter.convertType (adaptor.getOperand ().getType ());
330+ auto resultType = typeConverter.convertType (op.getResult ().getType ());
331+ if (!operandType || !resultType)
326332 return failure ();
327333
328334 rewriter.replaceOpWithNewOp <LLVM::IsFPClass>(
329- op, op. getType () , adaptor.getOperand (), llvm::fcFinite);
335+ op, resultType , adaptor.getOperand (), llvm::fcFinite);
330336 return success ();
331337 }
332338};
0 commit comments