2222#include " mlir/Transforms/DialectConversion.h"
2323#include " llvm/Support/FormatVariadic.h"
2424
25- #include " ../GPUCommon/GPUOpsLowering.h"
26-
2725namespace mlir {
2826#define GEN_PASS_DEF_CONVERTMATHTOXEVM
2927#include " mlir/Conversion/Passes.h.inc"
@@ -48,8 +46,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
4846 return failure ();
4947
5048 arith::FastMathFlags fastFlags = op.getFastmath ();
51- if (!(static_cast <uint32_t >(fastFlags) &
52- static_cast <uint32_t >(arith::FastMathFlags::afn)))
49+ if (!arith::bitEnumContainsAll (fastFlags, arith::FastMathFlags::afn))
5350 return rewriter.notifyMatchFailure (op, " not a fastmath `afn` operation" );
5451
5552 SmallVector<Type, 1 > operandTypes;
@@ -83,9 +80,9 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
8380 }
8481
8582 inline bool isSPIRVCompatibleFloatOrVec (Type type) const {
86- if (type.isFloat ()) {
83+ if (type.isFloat ())
8784 return true ;
88- } else if (auto vecType = dyn_cast<VectorType>(type)) {
85+ if (auto vecType = dyn_cast<VectorType>(type)) {
8986 if (!vecType.getElementType ().isFloat ())
9087 return false ;
9188 // SPIRV distinguishes between vectors and matrices: OpenCL native math
@@ -170,8 +167,7 @@ void ConvertMathToXeVMPass::runOnOperation() {
170167 RewritePatternSet patterns (&getContext ());
171168 populateMathToXeVMConversionPatterns (patterns, convertArith);
172169 ConversionTarget target (getContext ());
173- target.addLegalDialect <BuiltinDialect, func::FuncDialect,
174- vector::VectorDialect, LLVM::LLVMDialect>();
170+ target.addLegalDialect <BuiltinDialect, LLVM::LLVMDialect>();
175171 if (failed (
176172 applyPartialConversion (getOperation (), target, std::move (patterns))))
177173 signalPassFailure ();
0 commit comments