Skip to content

Commit aefb5e7

Browse files
AdUhTkJmlanza
authored andcommitted
[CIR][NFC] Generalize IdiomRecognizer (#1484)
The comments suggested that we should use TableGen to generate the recognizing functions. However, I think templates might be more suitable for generating them -- and I can't find any existing TableGen backends that let us generate arbitrary functions. My choice of design is to offer a template to match standard library functions: ```cpp // matches std::find with 3 arguments, and raise it into StdFindOp StdRecognizer<3, StdFindOp, StdFuncsID::Find> ``` I have to use a TableGen'd enum to map names to IDs, as we can't pass string literals to template arguments easily in C++17. This also constraints design of future `StdXXXOp`s: they must take operands the same way of StdFindOp, where the first one is the original function, and the rest are function arguments. I'm not sure if this approach is the best way. Please tell me if you have concerns or any alternative ways.
1 parent 28b0986 commit aefb5e7

File tree

5 files changed

+131
-103
lines changed

5 files changed

+131
-103
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 6 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4700,72 +4700,6 @@ def FrameAddrOp : FuncAddrBuiltinOp<"frame_address"> {
47004700
}];
47014701
}
47024702

4703-
//===----------------------------------------------------------------------===//
4704-
// StdFindOp
4705-
//===----------------------------------------------------------------------===//
4706-
4707-
def StdFindOp : CIR_Op<"std.find", [SameFirstSecondOperandAndResultType]> {
4708-
let arguments = (ins FlatSymbolRefAttr:$original_fn,
4709-
CIR_AnyType:$first,
4710-
CIR_AnyType:$last,
4711-
CIR_AnyType:$pattern);
4712-
let summary = "std:find()";
4713-
let results = (outs CIR_AnyType:$result);
4714-
4715-
let description = [{
4716-
Search for `pattern` in data range from `first` to `last`. This currently
4717-
maps to only one form of `std::find`. The `original_fn` operand tracks the
4718-
mangled named that can be used when lowering to a `cir.call`.
4719-
4720-
Example:
4721-
4722-
```mlir
4723-
...
4724-
%result = cir.std.find(@original_fn,
4725-
%first : !T, %last : !T, %pattern : !P) -> !T
4726-
```
4727-
}];
4728-
4729-
let assemblyFormat = [{
4730-
`(`
4731-
$original_fn
4732-
`,` $first `:` type($first)
4733-
`,` $last `:` type($last)
4734-
`,` $pattern `:` type($pattern)
4735-
`)` `->` type($result) attr-dict
4736-
}];
4737-
let hasVerifier = 0;
4738-
}
4739-
4740-
//===----------------------------------------------------------------------===//
4741-
// IterBegin/End
4742-
//===----------------------------------------------------------------------===//
4743-
4744-
def IterBeginOp : CIR_Op<"iterator_begin"> {
4745-
let arguments = (ins FlatSymbolRefAttr:$original_fn, CIR_AnyType:$container);
4746-
let summary = "Returns an iterator to the first element of a container";
4747-
let results = (outs CIR_AnyType:$result);
4748-
let assemblyFormat = [{
4749-
`(`
4750-
$original_fn `,` $container `:` type($container)
4751-
`)` `->` type($result) attr-dict
4752-
}];
4753-
let hasVerifier = 0;
4754-
}
4755-
4756-
def IterEndOp : CIR_Op<"iterator_end"> {
4757-
let arguments = (ins FlatSymbolRefAttr:$original_fn, CIR_AnyType:$container);
4758-
let summary = "Returns an iterator to the element following the last element"
4759-
" of a container";
4760-
let results = (outs CIR_AnyType:$result);
4761-
let assemblyFormat = [{
4762-
`(`
4763-
$original_fn `,` $container `:` type($container)
4764-
`)` `->` type($result) attr-dict
4765-
}];
4766-
let hasVerifier = 0;
4767-
}
4768-
47694703
//===----------------------------------------------------------------------===//
47704704
// Floating Point Ops
47714705
//===----------------------------------------------------------------------===//
@@ -5755,4 +5689,10 @@ def SignBitOp : CIR_Op<"signbit", [Pure]> {
57555689
}];
57565690
}
57575691

5692+
//===----------------------------------------------------------------------===//
5693+
// Standard library function calls
5694+
//===----------------------------------------------------------------------===//
5695+
5696+
include "clang/CIR/Dialect/IR/CIRStdOps.td"
5697+
57585698
#endif // LLVM_CLANG_CIR_DIALECT_IR_CIROPS
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//===-- CIRStdOps.td - CIR standard library ops ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// Defines ops representing standard library calls
10+
///
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRSTDOPS
14+
#define LLVM_CLANG_CIR_DIALECT_IR_CIRSTDOPS
15+
16+
class CIRStdOp<string functionName, dag args, dag res, list<Trait> traits = []>:
17+
CIR_Op<"std." # functionName, traits> {
18+
string funcName = functionName;
19+
20+
let arguments = !con((ins FlatSymbolRefAttr:$original_fn), args);
21+
22+
let summary = "std::" # functionName # "()";
23+
let results = res;
24+
25+
let extraClassDeclaration = [{
26+
static constexpr unsigned getNumArgs() {
27+
return }] # !size(args) # [{;
28+
}
29+
static llvm::StringRef getFunctionName() {
30+
return "}] # functionName # [{";
31+
}
32+
}];
33+
34+
string argsAssemblyFormat = !interleave(
35+
!foreach(
36+
name,
37+
!foreach(i, !range(!size(args)), !getdagname(args, i)),
38+
!strconcat("$", name, " `:` type($", name, ")")
39+
), " `,` "
40+
);
41+
42+
string resultAssemblyFormat = !if(
43+
!empty(res),
44+
"",
45+
" `->` type($" # !getdagname(res, 0) # ")"
46+
);
47+
48+
let assemblyFormat = !strconcat("`(` ", argsAssemblyFormat,
49+
" `,` $original_fn `)`", resultAssemblyFormat,
50+
" attr-dict");
51+
52+
let hasVerifier = 0;
53+
}
54+
55+
def StdFindOp : CIRStdOp<"find",
56+
(ins CIR_AnyType:$first, CIR_AnyType:$last, CIR_AnyType:$pattern),
57+
(outs CIR_AnyType:$result),
58+
[SameFirstSecondOperandAndResultType]>;
59+
def IterBeginOp: CIRStdOp<"begin",
60+
(ins CIR_AnyType:$container),
61+
(outs CIR_AnyType:$result)>;
62+
def IterEndOp: CIRStdOp<"end",
63+
(ins CIR_AnyType:$container),
64+
(outs CIR_AnyType:$result)>;
65+
66+
#endif

clang/lib/CIR/Dialect/Transforms/IdiomRecognizer.cpp

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,52 @@ using namespace cir;
3232

3333
namespace {
3434

35+
// Recognizes a cir.call that calls a standard library function represented
36+
// by `TargetOp`, and raise it to that operation.
37+
template <typename TargetOp> class StdRecognizer {
38+
private:
39+
// Reserved for template specialization.
40+
static bool checkArguments(mlir::ValueRange) { return true; }
41+
42+
template <size_t... Indices>
43+
static TargetOp buildCall(CIRBaseBuilderTy &builder, CallOp call,
44+
std::index_sequence<Indices...>) {
45+
return builder.create<TargetOp>(call.getLoc(), call.getResult().getType(),
46+
call.getCalleeAttr(),
47+
call.getOperand(Indices)...);
48+
}
49+
50+
public:
51+
static bool raise(CallOp call, mlir::MLIRContext &context, bool remark) {
52+
constexpr int numArgs = TargetOp::getNumArgs();
53+
if (call.getNumOperands() != numArgs)
54+
return false;
55+
56+
auto callExprAttr = call.getAstAttr();
57+
llvm::StringRef stdFuncName = TargetOp::getFunctionName();
58+
if (!callExprAttr || !callExprAttr.isStdFunctionCall(stdFuncName))
59+
return false;
60+
61+
if (!checkArguments(call.getArgOperands()))
62+
return false;
63+
64+
if (remark)
65+
mlir::emitRemark(call.getLoc())
66+
<< "found call to std::" << stdFuncName << "()";
67+
68+
CIRBaseBuilderTy builder(context);
69+
builder.setInsertionPointAfter(call.getOperation());
70+
TargetOp op = buildCall(builder, call, std::make_index_sequence<numArgs>());
71+
call.replaceAllUsesWith(op);
72+
call.erase();
73+
return true;
74+
}
75+
};
76+
3577
struct IdiomRecognizerPass : public IdiomRecognizerBase<IdiomRecognizerPass> {
3678
IdiomRecognizerPass() = default;
3779
void runOnOperation() override;
3880
void recognizeCall(CallOp call);
39-
bool raiseStdFind(CallOp call);
4081
bool raiseIteratorBeginEnd(CallOp call);
4182

4283
// Handle pass options
@@ -88,30 +129,6 @@ struct IdiomRecognizerPass : public IdiomRecognizerBase<IdiomRecognizerPass> {
88129
};
89130
} // namespace
90131

91-
bool IdiomRecognizerPass::raiseStdFind(CallOp call) {
92-
// FIXME: tablegen all of this function.
93-
if (call.getNumOperands() != 3)
94-
return false;
95-
96-
auto callExprAttr = call.getAstAttr();
97-
if (!callExprAttr || !callExprAttr.isStdFunctionCall("find")) {
98-
return false;
99-
}
100-
101-
if (opts.emitRemarkFoundCalls())
102-
emitRemark(call.getLoc()) << "found call to std::find()";
103-
104-
CIRBaseBuilderTy builder(getContext());
105-
builder.setInsertionPointAfter(call.getOperation());
106-
auto findOp = builder.create<cir::StdFindOp>(
107-
call.getLoc(), call.getResult().getType(), call.getCalleeAttr(),
108-
call.getOperand(0), call.getOperand(1), call.getOperand(2));
109-
110-
call.replaceAllUsesWith(findOp);
111-
call.erase();
112-
return true;
113-
}
114-
115132
static bool isIteratorLikeType(mlir::Type t) {
116133
// TODO: some iterators are going to be represented with structs,
117134
// in which case we could look at ASTRecordDeclInterface for more
@@ -175,8 +192,16 @@ void IdiomRecognizerPass::recognizeCall(CallOp call) {
175192
if (raiseIteratorBeginEnd(call))
176193
return;
177194

178-
if (raiseStdFind(call))
179-
return;
195+
bool remark = opts.emitRemarkFoundCalls();
196+
197+
using StdFunctionsRecognizer = std::tuple<StdRecognizer<StdFindOp>>;
198+
199+
// MSVC requires explicitly capturing these variables.
200+
std::apply(
201+
[&, call, remark, this](auto... recognizers) {
202+
(decltype(recognizers)::raise(call, this->getContext(), remark) || ...);
203+
},
204+
StdFunctionsRecognizer());
180205
}
181206

182207
void IdiomRecognizerPass::runOnOperation() {

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,8 +1481,7 @@ void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
14811481
CIRBaseBuilderTy builder(getContext());
14821482
builder.setInsertionPointAfter(op.getOperation());
14831483
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
1484-
op.getResult().getType(),
1485-
mlir::ValueRange{op.getOperand()});
1484+
op.getResult().getType(), op.getOperand());
14861485

14871486
op.replaceAllUsesWith(call);
14881487
op.erase();
@@ -1492,8 +1491,7 @@ void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
14921491
CIRBaseBuilderTy builder(getContext());
14931492
builder.setInsertionPointAfter(op.getOperation());
14941493
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
1495-
op.getResult().getType(),
1496-
mlir::ValueRange{op.getOperand()});
1494+
op.getResult().getType(), op.getOperand());
14971495

14981496
op.replaceAllUsesWith(call);
14991497
op.erase();

clang/test/CIR/Transforms/idiom-recognizer.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ int test_find(unsigned char n = 3)
1818
// expected-remark@-2 {{found call to end() iterator}}
1919

2020
// BEFORE-IDIOM: {{.*}} cir.call @_ZNSt5arrayIhLj9EE5beginEv(
21-
// AFTER-IDIOM: {{.*}} cir.iterator_begin(@_ZNSt5arrayIhLj9EE5beginEv,
21+
// AFTER-IDIOM: {{.*}} cir.std.begin({{.*}}, @_ZNSt5arrayIhLj9EE5beginEv
2222
// AFTER-LOWERING-PREPARE: {{.*}} cir.call @_ZNSt5arrayIhLj9EE5beginEv(
2323

2424
// BEFORE-IDIOM: {{.*}} cir.call @_ZNSt5arrayIhLj9EE3endEv(
25-
// AFTER-IDIOM: {{.*}} cir.iterator_end(@_ZNSt5arrayIhLj9EE3endEv,
25+
// AFTER-IDIOM: {{.*}} cir.std.end({{.*}}, @_ZNSt5arrayIhLj9EE3endEv
2626
// AFTER-LOWERING-PREPARE: {{.*}} cir.call @_ZNSt5arrayIhLj9EE3endEv(
2727

2828
// BEFORE-IDIOM: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
29-
// AFTER-IDIOM: {{.*}} cir.std.find(@_ZSt4findIPhhET_S1_S1_RKT0_,
29+
// AFTER-IDIOM: {{.*}} cir.std.find({{.*}}, @_ZSt4findIPhhET_S1_S1_RKT0_
3030
// AFTER-LOWERING-PREPARE: {{.*}} cir.call @_ZSt4findIPhhET_S1_S1_RKT0_(
3131

3232
if (f != v.end()) // expected-remark {{found call to end() iterator}}
@@ -43,8 +43,7 @@ template<typename T, unsigned N> struct array {
4343
};
4444
}
4545

46-
int iter_test()
47-
{
46+
void iter_test() {
4847
yolo::array<unsigned char, 3> v = {1, 2, 3};
4948
(void)v.begin(); // no remark should be produced.
50-
}
49+
}

0 commit comments

Comments
 (0)