Skip to content

Commit ff12413

Browse files
[WIP] 1:N conversion pattern
do not build argument materializations anymore fix more tests Fix decompose call graph test
1 parent d0d0632 commit ff12413

File tree

11 files changed

+663
-454
lines changed

11 files changed

+663
-454
lines changed

mlir/artifacts/jq-linux64

3.77 MB
Binary file not shown.

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
537537
ConversionPatternRewriter &rewriter) const {
538538
llvm_unreachable("unimplemented rewrite");
539539
}
540+
virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
541+
ConversionPatternRewriter &rewriter) const {
542+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
543+
}
540544

541545
/// Hook for derived classes to implement combined matching and rewriting.
542546
virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
547551
rewrite(op, operands, rewriter);
548552
return success();
549553
}
554+
virtual LogicalResult
555+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
556+
ConversionPatternRewriter &rewriter) const {
557+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
558+
}
550559

551560
/// Attempt to match and rewrite the IR root at the specified operation.
552561
LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,9 @@ class ConversionPattern : public RewritePattern {
574583
: RewritePattern(std::forward<Args>(args)...),
575584
typeConverter(&typeConverter) {}
576585

586+
SmallVector<Value>
587+
getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands) const;
588+
577589
protected:
578590
/// An optional type converter for use by this pattern.
579591
const TypeConverter *typeConverter = nullptr;
@@ -589,6 +601,8 @@ template <typename SourceOp>
589601
class OpConversionPattern : public ConversionPattern {
590602
public:
591603
using OpAdaptor = typename SourceOp::Adaptor;
604+
using OneToNOpAdaptor =
605+
typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
592606

593607
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
594608
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +621,24 @@ class OpConversionPattern : public ConversionPattern {
607621
auto sourceOp = cast<SourceOp>(op);
608622
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
609623
}
624+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
625+
ConversionPatternRewriter &rewriter) const final {
626+
auto sourceOp = cast<SourceOp>(op);
627+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
628+
}
610629
LogicalResult
611630
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
612631
ConversionPatternRewriter &rewriter) const final {
613632
auto sourceOp = cast<SourceOp>(op);
614633
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
615634
}
635+
LogicalResult
636+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
637+
ConversionPatternRewriter &rewriter) const final {
638+
auto sourceOp = cast<SourceOp>(op);
639+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
640+
rewriter);
641+
}
616642

617643
/// Rewrite and Match methods that operate on the SourceOp type. These must be
618644
/// overridden by the derived pattern class.
@@ -623,6 +649,12 @@ class OpConversionPattern : public ConversionPattern {
623649
ConversionPatternRewriter &rewriter) const {
624650
llvm_unreachable("must override matchAndRewrite or a rewrite method");
625651
}
652+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
653+
ConversionPatternRewriter &rewriter) const {
654+
SmallVector<Value> oneToOneOperands =
655+
getOneToOneAdaptorOperands(adaptor.getOperands());
656+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
657+
}
626658
virtual LogicalResult
627659
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
628660
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +663,13 @@ class OpConversionPattern : public ConversionPattern {
631663
rewrite(op, adaptor, rewriter);
632664
return success();
633665
}
666+
virtual LogicalResult
667+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
668+
ConversionPatternRewriter &rewriter) const {
669+
SmallVector<Value> oneToOneOperands =
670+
getOneToOneAdaptorOperands(adaptor.getOperands());
671+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
672+
}
634673

635674
private:
636675
using ConversionPattern::matchAndRewrite;
@@ -656,18 +695,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656695
ConversionPatternRewriter &rewriter) const final {
657696
rewrite(cast<SourceOp>(op), operands, rewriter);
658697
}
698+
void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
699+
ConversionPatternRewriter &rewriter) const final {
700+
rewrite(cast<SourceOp>(op), operands, rewriter);
701+
}
659702
LogicalResult
660703
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
661704
ConversionPatternRewriter &rewriter) const final {
662705
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
663706
}
707+
LogicalResult
708+
matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
709+
ConversionPatternRewriter &rewriter) const final {
710+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
711+
}
664712

665713
/// Rewrite and Match methods that operate on the SourceOp type. These must be
666714
/// overridden by the derived pattern class.
667715
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
668716
ConversionPatternRewriter &rewriter) const {
669717
llvm_unreachable("must override matchAndRewrite or a rewrite method");
670718
}
719+
virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
720+
ConversionPatternRewriter &rewriter) const {
721+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
722+
}
671723
virtual LogicalResult
672724
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
673725
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +728,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676728
rewrite(op, operands, rewriter);
677729
return success();
678730
}
731+
virtual LogicalResult
732+
matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
733+
ConversionPatternRewriter &rewriter) const {
734+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
735+
}
679736

680737
private:
681738
using ConversionPattern::matchAndRewrite;

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
/*
156157
// Argument materializations convert from the new block argument types
157158
// (multiple SSA values that make up a memref descriptor) back to the
158159
// original block argument type. The dialect conversion framework will then
@@ -198,16 +199,62 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
198199
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
199200
.getResult(0);
200201
});
202+
203+
*/
201204
// Add generic source and target materializations to handle cases where
202205
// non-LLVM types persist after an LLVM conversion.
203206
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
204207
ValueRange inputs, Location loc) {
205-
if (inputs.size() != 1)
206-
return Value();
208+
//if (inputs.size() != 1)
209+
// return Value();
207210

208211
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
209212
.getResult(0);
210213
});
214+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
215+
ValueRange inputs, Location loc) {
216+
if (inputs.size()== 1 && isa<LLVM::LLVMStructType>(inputs.front().getType())) return Value();
217+
218+
Value desc;
219+
if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
220+
// This is a bare pointer. We allow bare pointers only for function entry
221+
// blocks.
222+
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
223+
if (!barePtr)
224+
return Value();
225+
Block *block = barePtr.getOwner();
226+
if (!block->isEntryBlock() ||
227+
!isa<FunctionOpInterface>(block->getParentOp()))
228+
return Value();
229+
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
230+
inputs[0]);
231+
} else {
232+
//llvm::errs() << "pack elems: " << inputs.size() << "\n";
233+
//llvm::errs() << inputs[0] << "\n";
234+
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
235+
//llvm::errs() << "done packing\n";
236+
}
237+
// An argument materialization must return a value of type `resultType`,
238+
// so insert a cast from the memref descriptor type (!llvm.struct) to the
239+
// original memref type.
240+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
241+
.getResult(0);
242+
});
243+
addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType,
244+
ValueRange inputs, Location loc) {
245+
if (inputs.size() == 1) {
246+
// Bare pointers are not supported for unranked memrefs because a
247+
// memref descriptor cannot be built just from a bare pointer.
248+
return Value();
249+
}
250+
Value desc =
251+
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
252+
// An argument materialization must return a value of type
253+
// `resultType`, so insert a cast from the memref descriptor type
254+
// (!llvm.struct) to the original memref type.
255+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
256+
.getResult(0);
257+
});
211258
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212259
ValueRange inputs, Location loc) {
213260
if (inputs.size() != 1)
@@ -216,6 +263,51 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
216263
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
217264
.getResult(0);
218265
});
266+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
267+
ValueRange inputs,
268+
Location loc, Type originalType) -> Value {
269+
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
270+
if (!originalType) {
271+
llvm::errs() << " -- no orig\n";
272+
return Value();
273+
}
274+
if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
275+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
276+
if (inputs.size() == 1) {
277+
Value input = inputs.front();
278+
if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
279+
if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
280+
input = castOp.getInputs()[0];
281+
}
282+
}
283+
if (!isa<LLVM::LLVMPointerType>(input.getType()))
284+
return Value();
285+
BlockArgument barePtr = dyn_cast<BlockArgument>(input);
286+
if (!barePtr)
287+
return Value();
288+
Block *block = barePtr.getOwner();
289+
if (!block->isEntryBlock() ||
290+
!isa<FunctionOpInterface>(block->getParentOp()))
291+
return Value();
292+
// Bare ptr
293+
return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
294+
input);
295+
}
296+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
297+
}
298+
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
299+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
300+
if (inputs.size() == 1) {
301+
// Bare pointers are not supported for unranked memrefs because a
302+
// memref descriptor cannot be built just from a bare pointer.
303+
return Value();
304+
}
305+
return UnrankedMemRefDescriptor::pack(builder, loc, *this,
306+
memrefType, inputs);
307+
}
308+
309+
return Value();
310+
});
219311

220312
// Integer memory spaces map to themselves.
221313
addTypeAttributeConversion(

0 commit comments

Comments
 (0)