diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h index c820b83834de6..35749dae5d7e9 100644 --- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h +++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h @@ -195,6 +195,8 @@ template class FIROpConversion : public ConvertFIRToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor< + mlir::ArrayRef>; explicit FIROpConversion(const LLVMTypeConverter &typeConverter, const fir::FIRToLLVMPassOptions &options, @@ -209,6 +211,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { rewrite(mlir::cast(op), OpAdaptor(operands, mlir::cast(op)), rewriter); } + void rewrite(mlir::Operation *op, mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + auto sourceOp = llvm::cast(op); + rewrite(llvm::cast(op), OneToNOpAdaptor(operands, sourceOp), + rewriter); + } llvm::LogicalResult match(mlir::Operation *op) const final { return match(mlir::cast(op)); } @@ -219,7 +227,14 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { OpAdaptor(operands, mlir::cast(op)), rewriter); } - + llvm::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + auto sourceOp = mlir::cast(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. virtual llvm::LogicalResult match(SourceOp op) const { @@ -229,6 +244,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { mlir::ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual llvm::LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { @@ -237,6 +258,13 @@ class FIROpConversion : public ConvertFIRToLLVMPattern { rewrite(op, adaptor, rewriter); return mlir::success(); } + virtual llvm::LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConvertFIRToLLVMPattern::matchAndRewrite; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 613fd6d9d74b1..cedf645e2985d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1204,7 +1204,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // TODO: This is a 1:N conversion. The conversion value mapping does not // store such materializations yet. If the types of the most recently // mapped values do not match, build a target materialization. - if (TypeRange(unpacked) == legalTypes) { + ValueRange unpackedRange(unpacked); + if (TypeRange(unpackedRange) == legalTypes) { remapped.push_back(std::move(unpacked)); continue; } @@ -1677,7 +1678,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector newVals; - for (int i = 0; i < newValues.size(); ++i) + for (size_t i = 0; i < newValues.size(); ++i) newVals.push_back(newValues.slice(i, 1)); impl->notifyOpReplaced(op, newVals); } @@ -2669,8 +2670,11 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, break; } if (!newMaterialization.empty()) { - assert(TypeRange(newMaterialization) == op.getResultTypes() && +#ifndef NDEBUG + ValueRange newMaterializationRange(newMaterialization); + assert(TypeRange(newMaterializationRange) == op.getResultTypes() && "materialization callback produced value of incorrect type"); +#endif // NDEBUG rewriter.replaceOp(op, newMaterialization); return success(); }