Skip to content

Commit 55c4e43

Browse files
committed
Visitor: Add support for nested visitor clients
The idea here is that we can compose a complex lowering pass out of smaller "building blocks" that are all implemented in terms of visiting functions.
1 parent da31a16 commit 55c4e43

File tree

4 files changed

+255
-28
lines changed

4 files changed

+255
-28
lines changed

example/ExampleMain.cpp

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ using namespace llvm_dialects;
4242
enum class Action {
4343
Build,
4444
Verify,
45+
Visit,
4546
};
4647

4748
cl::opt<Action> g_action(
4849
cl::desc("Action to perform:"), cl::init(Action::Build),
4950
cl::values(clEnumValN(Action::Build, "build", "Example IRBuilder use"),
50-
clEnumValN(Action::Verify, "verify", "Verify an input module")));
51+
clEnumValN(Action::Verify, "verify", "Verify an input module"),
52+
clEnumValN(Action::Visit, "visit", "Example Visitor use")));
5153

5254
// Input sources
5355
cl::list<std::string> g_inputs(cl::Positional, cl::ZeroOrMore,
@@ -132,6 +134,54 @@ std::unique_ptr<Module> createModuleExampleTypedPtrs(LLVMContext &context) {
132134
return module;
133135
}
134136

137+
struct VisitorInnermost {
138+
int counter = 0;
139+
};
140+
141+
struct VisitorNest {
142+
raw_ostream *out = nullptr;
143+
VisitorInnermost inner;
144+
};
145+
146+
struct VisitorContainer {
147+
int padding;
148+
VisitorNest nest;
149+
};
150+
151+
template <>
152+
struct llvm_dialects::VisitorPayloadProjection<VisitorNest, raw_ostream> {
153+
static raw_ostream &project(VisitorNest &nest) { return *nest.out; }
154+
};
155+
156+
LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorContainer, nest)
157+
LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(VisitorNest, inner)
158+
159+
void exampleVisit(Module &module) {
160+
auto visitor =
161+
VisitorBuilder<VisitorContainer>()
162+
.nest<VisitorNest>([](VisitorBuilder<VisitorNest> &b) {
163+
b.add<xd::ReadOp>([](VisitorNest &self, xd::ReadOp &op) {
164+
*self.out << "visiting ReadOp: " << op << '\n';
165+
});
166+
b.nest<raw_ostream>([](VisitorBuilder<raw_ostream> &b) {
167+
b.add<xd::WriteOp>([](raw_ostream &out, xd::WriteOp &op) {
168+
out << "visiting WriteOp: " << op << '\n';
169+
});
170+
});
171+
b.nest<VisitorInnermost>([](VisitorBuilder<VisitorInnermost> &b) {
172+
b.add<xd::ITruncOp>([](VisitorInnermost &inner,
173+
xd::ITruncOp &op) { inner.counter++; });
174+
});
175+
})
176+
.build();
177+
178+
VisitorContainer container;
179+
container.nest.out = &outs();
180+
visitor.visit(container, module);
181+
182+
outs() << "inner.counter = " << container.nest.inner.counter << '\n';
183+
}
184+
135185
int main(int argc, char **argv) {
136186
llvm::cl::ParseCommandLineOptions(argc, argv);
137187

@@ -143,7 +193,7 @@ int main(int argc, char **argv) {
143193
if (g_action == Action::Build) {
144194
auto module = g_typedPointers ? createModuleExampleTypedPtrs(context) : createModuleExample(context);
145195
module->print(llvm::outs(), nullptr, false);
146-
} else if (g_action == Action::Verify) {
196+
} else {
147197
if (g_inputs.size() != 1) {
148198
errs() << "Need exactly one input module\n";
149199
return 1;
@@ -164,10 +214,14 @@ int main(int argc, char **argv) {
164214
return 1;
165215
}
166216

167-
if (!verify(*module, outs()))
168-
return 1;
169-
} else {
170-
report_fatal_error("unhandled action");
217+
if (g_action == Action::Verify) {
218+
if (!verify(*module, outs()))
219+
return 1;
220+
} else if (g_action == Action::Visit) {
221+
exampleVisit(*module);
222+
} else {
223+
report_fatal_error("unhandled action");
224+
}
171225
}
172226

173227
return 0;

include/llvm-dialects/Dialect/Visitor.h

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,37 +45,107 @@ enum class VisitorStrategy {
4545
ByFunctionDeclaration,
4646
};
4747

48+
/// @brief Auxiliary template to support nested visitor clients.
49+
///
50+
/// It can be convenient to write visitor clients in a modular fashion and
51+
/// combine them together to use a single @ref Visitor. Each client module
52+
/// may use a different payload type.
53+
///
54+
/// In this case, each nested client module's payload must be reachable from
55+
/// the top-level payload, and a specialization of this template must be
56+
/// provided that "projects" parent payload references to child payload
57+
/// references.
58+
///
59+
/// In cases where the nested payload is a field of the parent payload, the
60+
/// @ref LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD can be used instead.
61+
template <typename PayloadT, typename NestedPayloadT>
62+
struct VisitorPayloadProjection {
63+
// Template specializations must implement this static method:
64+
//
65+
// static NestedPayloadT &project(PayloadT &);
66+
};
67+
68+
/// Declare that `PayloadT` can be projected to a nested payload type via
69+
/// `field`.
70+
///
71+
/// This creates a specialization of @ref VisitorPayloadProjection and must
72+
/// therefore typically be outside of any namespace. The nested type is derived
73+
/// automatically.
74+
#define LLVM_DIALECTS_VISITOR_PAYLOAD_PROJECT_FIELD(PayloadT, field) \
75+
template <> \
76+
struct llvm_dialects::detail::VisitorPayloadOffsetProjection< \
77+
PayloadT, \
78+
std::remove_reference_t<decltype(std::declval<PayloadT>().field)>> { \
79+
static constexpr bool useOffsetof = true; \
80+
static constexpr std::size_t offset = offsetof(PayloadT, field); \
81+
};
82+
4883
namespace detail {
4984

5085
class VisitorBase;
5186

5287
using VisitorCallback = void (void *, void *, llvm::Instruction *);
53-
using VisitorCase = std::tuple<const OpDescription *, void *, VisitorCallback *>;
88+
using PayloadProjectionCallback = void *(void *);
89+
90+
/// Apply first the byte offset and then the projection function. If projection
91+
/// is null, stop the projection sequence.
92+
struct PayloadProjection {
93+
std::size_t offset = 0;
94+
PayloadProjectionCallback *projection = nullptr;
95+
};
96+
97+
struct VisitorCase {
98+
const OpDescription *description = nullptr;
99+
VisitorCallback *callback = nullptr;
100+
void *callbackData = nullptr;
101+
102+
// If non-negative, a byte offset to apply to the payload. If negative,
103+
// a shifted index into the projections vector.
104+
ssize_t projection = 0;
105+
};
106+
107+
template <typename PayloadT, typename NestedPayloadT>
108+
struct VisitorPayloadOffsetProjection {
109+
static constexpr bool useOffsetof = false;
110+
};
54111

55112
class VisitorBuilderBase {
56113
friend class VisitorBase;
57114
public:
115+
VisitorBuilderBase() = default;
116+
explicit VisitorBuilderBase(VisitorBuilderBase *parent) : m_parent(parent) {}
117+
~VisitorBuilderBase();
118+
58119
void setStrategy(VisitorStrategy strategy) { m_strategy = strategy; }
59120

60-
protected:
61121
void add(const OpDescription &desc, void *extra, VisitorCallback *fn);
62122

123+
public:
124+
PayloadProjectionCallback *m_projection = nullptr;
125+
size_t m_offsetProjection = 0;
126+
63127
private:
128+
VisitorBuilderBase *m_parent = nullptr;
64129
VisitorStrategy m_strategy = VisitorStrategy::ByFunctionDeclaration;
65130
llvm::SmallVector<VisitorCase> m_cases;
131+
llvm::SmallVector<PayloadProjection> m_projections;
66132
};
67133

68134
class VisitorBase {
69135
protected:
70-
VisitorBase(VisitorBuilderBase builder);
136+
VisitorBase(VisitorBuilderBase &&builder);
71137

72138
void visit(void *payload, llvm::Instruction &inst) const;
73139
void visit(void *payload, llvm::Function &fn) const;
74140
void visit(void *payload, llvm::Module &module) const;
75141

76142
private:
143+
void call(const VisitorCase &theCase, void *payload,
144+
llvm::Instruction &inst) const;
145+
77146
VisitorStrategy m_strategy;
78147
llvm::SmallVector<VisitorCase> m_cases;
148+
llvm::SmallVector<PayloadProjection> m_projections;
79149
};
80150

81151
} // namespace detail
@@ -91,7 +161,7 @@ class VisitorBase {
91161
template <typename PayloadT>
92162
class Visitor : public detail::VisitorBase {
93163
public:
94-
Visitor(detail::VisitorBuilderBase builder)
164+
Visitor(detail::VisitorBuilderBase &&builder)
95165
: VisitorBase(std::move(builder)) {}
96166

97167
void visit(PayloadT &payload, llvm::Instruction &inst) const {
@@ -130,16 +200,20 @@ class Visitor : public detail::VisitorBase {
130200
/// myVisitor(myPayload, module);
131201
/// @endcode
132202
template <typename PayloadT>
133-
class VisitorBuilder : public detail::VisitorBuilderBase {
203+
class VisitorBuilder : private detail::VisitorBuilderBase {
204+
template <typename OtherT> friend class VisitorBuilder;
205+
134206
public:
135207
using Payload = PayloadT;
136208

209+
VisitorBuilder() = default;
210+
137211
VisitorBuilder &setStrategy(VisitorStrategy strategy) {
138212
VisitorBuilderBase::setStrategy(strategy);
139213
return *this;
140214
}
141215

142-
Visitor<PayloadT> build() { return {std::move(*this)}; }
216+
Visitor<PayloadT> build() { return std::move(*this); }
143217

144218
template <typename OpT> VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) {
145219
VisitorBuilderBase::add(
@@ -151,6 +225,32 @@ class VisitorBuilder : public detail::VisitorBuilderBase {
151225
});
152226
return *this;
153227
}
228+
229+
template <typename NestedPayloadT>
230+
VisitorBuilder &nest(void (*registration)(VisitorBuilder<NestedPayloadT> &)) {
231+
VisitorBuilder<NestedPayloadT> nested{this};
232+
233+
if constexpr (detail::VisitorPayloadOffsetProjection<
234+
PayloadT, NestedPayloadT>::useOffsetof) {
235+
nested.m_offsetProjection =
236+
detail::VisitorPayloadOffsetProjection<PayloadT,
237+
NestedPayloadT>::offset;
238+
} else {
239+
nested.m_projection = [](void *payload) -> void * {
240+
return static_cast<void *>(
241+
&VisitorPayloadProjection<PayloadT, NestedPayloadT>::project(
242+
*static_cast<PayloadT *>(payload)));
243+
};
244+
}
245+
246+
(*registration)(nested);
247+
248+
return *this;
249+
}
250+
251+
private:
252+
explicit VisitorBuilder(VisitorBuilderBase *parent)
253+
: VisitorBuilderBase(parent) {}
154254
};
155255

156256
} // namespace llvm_dialects

lib/Dialect/Visitor.cpp

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,74 @@ using namespace llvm;
2929
using llvm_dialects::detail::VisitorBuilderBase;
3030
using llvm_dialects::detail::VisitorBase;
3131

32+
VisitorBuilderBase::~VisitorBuilderBase() {
33+
if (m_parent) {
34+
// Build up the projection sequence by walking all the way to the root.
35+
SmallVector<PayloadProjection> projectionSequence;
36+
projectionSequence.emplace_back();
37+
38+
VisitorBuilderBase *ancestor = this;
39+
for (; ancestor->m_parent; ancestor = ancestor->m_parent) {
40+
if (ancestor->m_projection) {
41+
PayloadProjection next;
42+
next.projection = ancestor->m_projection;
43+
projectionSequence.push_back(next);
44+
} else {
45+
projectionSequence.back().offset += ancestor->m_offsetProjection;
46+
}
47+
}
48+
49+
ssize_t projection = 0;
50+
if (projectionSequence.size() == 1) {
51+
projection = projectionSequence[0].offset;
52+
} else {
53+
projection = -(1 + ancestor->m_projections.size());
54+
ancestor->m_projections.append(projectionSequence.rbegin(),
55+
projectionSequence.rend());
56+
}
57+
58+
for (auto &theCase : m_cases)
59+
theCase.projection = projection;
60+
61+
// Copy all cases to the root.
62+
ancestor->m_cases.append(m_cases.begin(), m_cases.end());
63+
}
64+
}
65+
3266
void VisitorBuilderBase::add(const OpDescription &desc, void *extra, VisitorCallback *fn) {
33-
m_cases.emplace_back(&desc, extra, fn);
67+
VisitorCase theCase;
68+
theCase.description = &desc;
69+
theCase.callback = fn;
70+
theCase.callbackData = extra;
71+
theCase.projection = 0;
72+
m_cases.emplace_back(theCase);
3473
}
3574

36-
VisitorBase::VisitorBase(VisitorBuilderBase builder)
37-
: m_strategy(builder.m_strategy), m_cases(std::move(builder.m_cases)) {
75+
VisitorBase::VisitorBase(VisitorBuilderBase &&builder)
76+
: m_strategy(builder.m_strategy), m_cases(std::move(builder.m_cases)),
77+
m_projections(std::move(builder.m_projections)) {
78+
assert(!builder.m_parent);
79+
}
80+
81+
void VisitorBase::call(const VisitorCase &theCase, void *payload,
82+
Instruction &inst) const {
83+
if (theCase.projection >= 0) {
84+
payload = (char *)payload + theCase.projection;
85+
} else {
86+
for (size_t idx = -theCase.projection - 1;; ++idx) {
87+
payload = (char *)payload + m_projections[idx].offset;
88+
if (!m_projections[idx].projection)
89+
break;
90+
payload = m_projections[idx].projection(payload);
91+
}
92+
}
93+
theCase.callback(theCase.callbackData, payload, &inst);
3894
}
3995

4096
void VisitorBase::visit(void *payload, Instruction &inst) const {
41-
for (const auto &[desc, extra, callback] : m_cases) {
42-
if (desc->matchInstruction(inst))
43-
callback(extra, payload, &inst);
97+
for (const auto &theCase : m_cases) {
98+
if (theCase.description->matchInstruction(inst))
99+
call(theCase, payload, inst);
44100
}
45101
}
46102

@@ -59,15 +115,15 @@ void VisitorBase::visit(void *payload, Function &fn) const {
59115

60116
LLVM_DEBUG(dbgs() << "visit " << decl.getName() << '\n');
61117

62-
for (const auto &[desc, extra, callback] : m_cases) {
63-
if (desc->matchDeclaration(decl)) {
118+
for (const auto &theCase : m_cases) {
119+
if (theCase.description->matchDeclaration(decl)) {
64120
for (Use &use : decl.uses()) {
65121
if (auto *inst = dyn_cast<Instruction>(use.getUser())) {
66122
if (inst->getFunction() != &fn)
67123
continue;
68-
if (auto *call = dyn_cast<CallInst>(inst)) {
69-
if (&use == &call->getCalledOperandUse())
70-
callback(extra, payload, call);
124+
if (auto *callInst = dyn_cast<CallInst>(inst)) {
125+
if (&use == &callInst->getCalledOperandUse())
126+
call(theCase, payload, *callInst);
71127
}
72128
}
73129
}
@@ -89,12 +145,12 @@ void VisitorBase::visit(void *payload, Module &module) const {
89145
if (!decl.isDeclaration())
90146
continue;
91147

92-
for (const auto &[desc, extra, callback] : m_cases) {
93-
if (desc->matchDeclaration(decl)) {
148+
for (const auto &theCase : m_cases) {
149+
if (theCase.description->matchDeclaration(decl)) {
94150
for (Use &use : decl.uses()) {
95-
if (auto *call = dyn_cast<CallInst>(use.getUser())) {
96-
if (&use == &call->getCalledOperandUse())
97-
callback(extra, payload, call);
151+
if (auto *callInst = dyn_cast<CallInst>(use.getUser())) {
152+
if (&use == &callInst->getCalledOperandUse())
153+
call(theCase, payload, *callInst);
98154
}
99155
}
100156
}

0 commit comments

Comments
 (0)