diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 4f8301f9380b8..25d9c404f0181 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure, } }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; } @@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure, } }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; } @@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>, }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; } @@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; } @@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory }]; let extraClassDeclaration = [{ - SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); }]; } #endif // X86VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td index 5176f4a447b6e..cde9d1dce65ee 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td @@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> { }], /*retType=*/"SmallVector", /*methodName=*/"getIntrinsicOperands", - /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter), + /*args=*/(ins "::mlir::ArrayRef":$operands, + "const ::mlir::LLVMTypeConverter &":$typeConverter, + "::mlir::RewriterBase &":$rewriter), /*methodBody=*/"", - /*defaultImplementation=*/"return SmallVector($_op->getOperands());" + /*defaultImplementation=*/"return SmallVector(operands);" >, ]; } diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index 8d383b1f8103b..cc7ab7f3f3895 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() { >(); } -static SmallVector -getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal, - RewriterBase &rewriter, - const LLVMTypeConverter &typeConverter) { - SmallVector operands; - auto opType = memrefVal.getType(); - - Type llvmStructType = typeConverter.convertType(opType); - Value llvmStruct = - rewriter - .create(loc, llvmStructType, memrefVal) - .getResult(0); - MemRefDescriptor memRefDescriptor(llvmStruct); - - Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType); - operands.push_back(ptr); - - return operands; +static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + MemRefDescriptor memRefDescriptor(buffer); + return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type); } LogicalResult x86vector::MaskCompressOp::verify() { @@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() { } SmallVector x86vector::MaskCompressOp::getIntrinsicOperands( - RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { auto loc = getLoc(); + Adaptor adaptor(operands, *this); - auto opType = getA().getType(); + auto opType = adaptor.getA().getType(); Value src; - if (getSrc()) { - src = getSrc(); - } else if (getConstantSrc()) { - src = rewriter.create(loc, opType, getConstantSrcAttr()); + if (adaptor.getSrc()) { + src = adaptor.getSrc(); + } else if (adaptor.getConstantSrc()) { + src = rewriter.create(loc, opType, + adaptor.getConstantSrcAttr()); } else { auto zeroAttr = rewriter.getZeroAttr(opType); src = rewriter.create(loc, opType, zeroAttr); } - return SmallVector{getA(), src, getK()}; + return SmallVector{adaptor.getA(), src, adaptor.getK()}; } SmallVector -x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter, - const LLVMTypeConverter &typeConverter) { - SmallVector operands(getOperands()); +x86vector::DotOp::getIntrinsicOperands(ArrayRef operands, + const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + SmallVector intrinsicOperands(operands); // Dot product of all elements, broadcasted to all elements. Value scale = rewriter.create(getLoc(), rewriter.getI8Type(), 0xff); - operands.push_back(scale); + intrinsicOperands.push_back(scale); - return operands; + return intrinsicOperands; } SmallVector x86vector::BcstToPackedF32Op::getIntrinsicOperands( - RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { - return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + Adaptor adaptor(operands, *this); + return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(), + typeConverter, rewriter)}; } SmallVector x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands( - RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { - return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + Adaptor adaptor(operands, *this); + return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(), + typeConverter, rewriter)}; } SmallVector x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands( - RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { - return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + Adaptor adaptor(operands, *this); + return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(), + typeConverter, rewriter)}; } #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index 9ee44a63ba2e4..483c1f5c3e4c6 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic, /// Generic one-to-one conversion of simply mappable operations into calls /// to their respective LLVM intrinsics. struct OneToOneIntrinsicOpConversion - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern< - x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern; + : public OpInterfaceConversionPattern { + using OpInterfaceConversionPattern< + x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern; OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit), + : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(), + benefit), typeConverter(typeConverter) {} - LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op, - PatternRewriter &rewriter) const override { - return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()), - op.getIntrinsicOperands(rewriter, typeConverter), - typeConverter, rewriter); + LogicalResult + matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + return intrinsicRewrite( + op, rewriter.getStringAttr(op.getIntrinsicName()), + op.getIntrinsicOperands(operands, typeConverter, rewriter), + typeConverter, rewriter); } private: