Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added mlir/artifacts/jq-linux64
Binary file not shown.
35 changes: 31 additions & 4 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;

explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
Expand All @@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
rewriter);
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
Expand All @@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConvertToLLVMPattern::match;
Expand Down
81 changes: 79 additions & 2 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult
Expand All @@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
rewrite(op, operands, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
Expand Down Expand Up @@ -574,6 +583,9 @@ class ConversionPattern : public RewritePattern {
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}

SmallVector<Value>
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands) const;

protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
Expand All @@ -589,6 +601,8 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;

OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
Expand All @@ -607,12 +621,24 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
Expand All @@ -623,6 +649,12 @@ class OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -631,6 +663,13 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand All @@ -656,18 +695,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -676,6 +728,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand Down Expand Up @@ -795,12 +852,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }

/// PatternRewriter hook for replacing an operation.
/// Replace the given operation with the new values. The number of op results
/// and replacement values must match. The types may differ: the dialect
/// conversion driver will reconcile any surviving type mismatches at the end
/// of the conversion process with source materializations. The given
/// operation is erased.
void replaceOp(Operation *op, ValueRange newValues) override;

/// PatternRewriter hook for replacing an operation.
/// Replace the given operation with the results of the new op. The number of
/// op results must match. The types may differ: the dialect conversion
/// driver will reconcile any surviving type mismatches at the end of the
/// conversion process with source materializations. The original operation
/// is erased.
void replaceOp(Operation *op, Operation *newOp) override;

/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. If an original SSA value is replaced
/// by multiple SSA values (i.e., a value range has more than 1 element), the
/// conversion driver will insert an argument materialization to convert the
/// N SSA values back into 1 SSA value of the original type. The given
/// operation is erased.
///
/// Note: The argument materialization is a workaround until we have full 1:N
/// support in the dialect conversion. (It is going to disappear from both
/// `replaceOpWithMultiple` and `applySignatureConversion`.)
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);

/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
/// otherwise an assert will be issued.
Expand Down
96 changes: 94 additions & 2 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});

/*
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
Expand Down Expand Up @@ -198,16 +199,62 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});

*/
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
return Value();
//if (inputs.size() != 1)
// return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
if (inputs.size()== 1 && isa<LLVM::LLVMStructType>(inputs.front().getType())) return Value();

Value desc;
if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
if (!barePtr)
return Value();
Block *block = barePtr.getOwner();
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return Value();
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else {
//llvm::errs() << "pack elems: " << inputs.size() << "\n";
//llvm::errs() << inputs[0] << "\n";
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
//llvm::errs() << "done packing\n";
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
return Value();
}
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
Expand All @@ -216,6 +263,51 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc, Type originalType) -> Value {
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
if (!originalType) {
llvm::errs() << " -- no orig\n";
return Value();
}
if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
if (inputs.size() == 1) {
Value input = inputs.front();
if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
input = castOp.getInputs()[0];
}
}
if (!isa<LLVM::LLVMPointerType>(input.getType()))
return Value();
BlockArgument barePtr = dyn_cast<BlockArgument>(input);
if (!barePtr)
return Value();
Block *block = barePtr.getOwner();
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return Value();
// Bare ptr
return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
input);
}
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
}
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
return Value();
}
return UnrankedMemRefDescriptor::pack(builder, loc, *this,
memrefType, inputs);
}

return Value();
});

// Integer memory spaces map to themselves.
addTypeAttributeConversion(
Expand Down
Loading
Loading