Skip to content

Commit 58389b2

Browse files
[mlir] Fix build after #116470 (#118147)
This should have been part of #116470.
1 parent 1f13713 commit 58389b2

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ template <typename SourceOp>
195195
class FIROpConversion : public ConvertFIRToLLVMPattern {
196196
public:
197197
using OpAdaptor = typename SourceOp::Adaptor;
198+
using OneToNOpAdaptor = typename SourceOp::template GenericAdaptor<
199+
mlir::ArrayRef<mlir::ValueRange>>;
198200

199201
explicit FIROpConversion(const LLVMTypeConverter &typeConverter,
200202
const fir::FIRToLLVMPassOptions &options,
@@ -209,6 +211,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
209211
rewrite(mlir::cast<SourceOp>(op),
210212
OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
211213
}
214+
void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
215+
mlir::ConversionPatternRewriter &rewriter) const final {
216+
auto sourceOp = llvm::cast<SourceOp>(op);
217+
rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
218+
rewriter);
219+
}
212220
llvm::LogicalResult match(mlir::Operation *op) const final {
213221
return match(mlir::cast<SourceOp>(op));
214222
}
@@ -219,7 +227,14 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
219227
OpAdaptor(operands, mlir::cast<SourceOp>(op)),
220228
rewriter);
221229
}
222-
230+
llvm::LogicalResult
231+
matchAndRewrite(mlir::Operation *op,
232+
mlir::ArrayRef<mlir::ValueRange> operands,
233+
mlir::ConversionPatternRewriter &rewriter) const final {
234+
auto sourceOp = mlir::cast<SourceOp>(op);
235+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
236+
rewriter);
237+
}
223238
/// Rewrite and Match methods that operate on the SourceOp type. These must be
224239
/// overridden by the derived pattern class.
225240
virtual llvm::LogicalResult match(SourceOp op) const {
@@ -229,6 +244,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
229244
mlir::ConversionPatternRewriter &rewriter) const {
230245
llvm_unreachable("must override rewrite or matchAndRewrite");
231246
}
247+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
248+
mlir::ConversionPatternRewriter &rewriter) const {
249+
llvm::SmallVector<mlir::Value> oneToOneOperands =
250+
getOneToOneAdaptorOperands(adaptor.getOperands());
251+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
252+
}
232253
virtual llvm::LogicalResult
233254
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
234255
mlir::ConversionPatternRewriter &rewriter) const {
@@ -237,6 +258,13 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
237258
rewrite(op, adaptor, rewriter);
238259
return mlir::success();
239260
}
261+
virtual llvm::LogicalResult
262+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
263+
mlir::ConversionPatternRewriter &rewriter) const {
264+
llvm::SmallVector<mlir::Value> oneToOneOperands =
265+
getOneToOneAdaptorOperands(adaptor.getOperands());
266+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
267+
}
240268

241269
private:
242270
using ConvertFIRToLLVMPattern::matchAndRewrite;

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12041204
// TODO: This is a 1:N conversion. The conversion value mapping does not
12051205
// store such materializations yet. If the types of the most recently
12061206
// mapped values do not match, build a target materialization.
1207-
if (TypeRange(unpacked) == legalTypes) {
1207+
ValueRange unpackedRange(unpacked);
1208+
if (TypeRange(unpackedRange) == legalTypes) {
12081209
remapped.push_back(std::move(unpacked));
12091210
continue;
12101211
}
@@ -1677,7 +1678,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
16771678
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
16781679
});
16791680
SmallVector<ValueRange> newVals;
1680-
for (int i = 0; i < newValues.size(); ++i)
1681+
for (size_t i = 0; i < newValues.size(); ++i)
16811682
newVals.push_back(newValues.slice(i, 1));
16821683
impl->notifyOpReplaced(op, newVals);
16831684
}
@@ -2669,8 +2670,11 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
26692670
break;
26702671
}
26712672
if (!newMaterialization.empty()) {
2672-
assert(TypeRange(newMaterialization) == op.getResultTypes() &&
2673+
#ifndef NDEBUG
2674+
ValueRange newMaterializationRange(newMaterialization);
2675+
assert(TypeRange(newMaterializationRange) == op.getResultTypes() &&
26732676
"materialization callback produced value of incorrect type");
2677+
#endif // NDEBUG
26742678
rewriter.replaceOp(op, newMaterialization);
26752679
return success();
26762680
}

0 commit comments

Comments
 (0)