Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ template <typename SourceOp>
class FIROpConversion : public ConvertFIRToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor<
mlir::ArrayRef<mlir::ValueRange>>;

explicit FIROpConversion(const LLVMTypeConverter &typeConverter,
const fir::FIRToLLVMPassOptions &options,
Expand All @@ -209,6 +211,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
rewrite(mlir::cast<SourceOp>(op),
OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
}
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto sourceOp = llvm::cast<SourceOp>(op);
rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
llvm::LogicalResult match(mlir::Operation *op) const final {
return match(mlir::cast<SourceOp>(op));
}
Expand All @@ -219,7 +227,14 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
OpAdaptor(operands, mlir::cast<SourceOp>(op)),
rewriter);
}

llvm::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::ArrayRef<mlir::ValueRange> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto sourceOp = mlir::cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual llvm::LogicalResult match(SourceOp op) const {
Expand All @@ -229,6 +244,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
mlir::ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm::SmallVector<mlir::Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand All @@ -237,6 +258,13 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
rewrite(op, adaptor, rewriter);
return mlir::success();
}
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm::SmallVector<mlir::Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConvertFIRToLLVMPattern::matchAndRewrite;
Expand Down
10 changes: 7 additions & 3 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// TODO: This is a 1:N conversion. The conversion value mapping does not
// store such materializations yet. If the types of the most recently
// mapped values do not match, build a target materialization.
if (TypeRange(unpacked) == legalTypes) {
ValueRange unpackedRange(unpacked);
if (TypeRange(unpackedRange) == legalTypes) {
remapped.push_back(std::move(unpacked));
continue;
}
Expand Down Expand Up @@ -1677,7 +1678,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
SmallVector<ValueRange> newVals;
for (int i = 0; i < newValues.size(); ++i)
for (size_t i = 0; i < newValues.size(); ++i)
newVals.push_back(newValues.slice(i, 1));
impl->notifyOpReplaced(op, newVals);
}
Expand Down Expand Up @@ -2669,8 +2670,11 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
break;
}
if (!newMaterialization.empty()) {
assert(TypeRange(newMaterialization) == op.getResultTypes() &&
#ifndef NDEBUG
ValueRange newMaterializationRange(newMaterialization);
assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
"materialization callback produced value of incorrect type");
#endif // NDEBUG
rewriter.replaceOp(op, newMaterialization);
return success();
}
Expand Down