Skip to content

Commit daa38d7

Browse files
authored
Add VisitorResult (#81)
Allow a visitor callback to decide whether additional matching callbacks for the same instruction should be called or not. This is useful when a visitor callback wants to erase an instruction and there are multiple visitors that might match the same instruction (as sometimes happens with generic instructions like load and store).
1 parent 69e114f commit daa38d7

File tree

4 files changed

+132
-35
lines changed

4 files changed

+132
-35
lines changed

example/ExampleMain.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ struct VisitorNest {
148148
void visitBinaryOperator(BinaryOperator &inst) {
149149
*out << "visiting BinaryOperator: " << inst << '\n';
150150
}
151+
VisitorResult visitUnaryInstruction(UnaryInstruction &inst) {
152+
*out << "visiting UnaryInstruction (pre): " << inst << '\n';
153+
return isa<LoadInst>(inst) ? VisitorResult::Stop : VisitorResult::Continue;
154+
}
151155
};
152156

153157
struct VisitorContainer {
@@ -181,6 +185,12 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
181185
b.add<xd::ReadOp>([](VisitorNest &self, xd::ReadOp &op) {
182186
*self.out << "visiting ReadOp: " << op << '\n';
183187
});
188+
b.add(&VisitorNest::visitUnaryInstruction);
189+
b.add<xd::SetReadOp>([](VisitorNest &self, xd::SetReadOp &op) {
190+
*self.out << "visiting SetReadOp: " << op << '\n';
191+
return op.getType()->isIntegerTy(1) ? VisitorResult::Stop
192+
: VisitorResult::Continue;
193+
});
184194
b.addSet<xd::SetReadOp, xd::SetWriteOp>(
185195
[](VisitorNest &self, llvm::Instruction &op) {
186196
if (isa<xd::SetReadOp>(op)) {

include/llvm-dialects/Dialect/Visitor.h

Lines changed: 106 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved.
2+
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -89,6 +89,25 @@ struct VisitorPayloadProjection {
8989
static constexpr std::size_t offset = offsetof(PayloadT, field); \
9090
};
9191

92+
/// @brief Possible result states of visitor callbacks
93+
///
94+
/// A visitor may have multiple callbacks registered that match on the same
95+
/// instruction. By default, all matching callbacks are invoked in the order in
96+
/// which they were registered with the visitor. This may not be appropriate.
97+
/// A common issue is when the callback erases and replaces the visited
98+
/// instruction.
99+
///
100+
/// Callbacks may explicitly return a result state to indicate whether further
101+
/// visits are desired.
102+
enum class VisitorResult {
103+
/// Continue with the next callbacks on the same instruction. This is the
104+
/// default when the callback does not return a value.
105+
Continue,
106+
107+
/// Skip subsequent callbacks
108+
Stop,
109+
};
110+
92111
namespace detail {
93112

94113
class VisitorBase;
@@ -158,8 +177,8 @@ struct VisitorCallbackData : public Foo0, Foo1 {
158177
char data[Size];
159178
};
160179

161-
using VisitorCallback = void(const VisitorCallbackData &, void *,
162-
llvm::Instruction *);
180+
using VisitorCallback = VisitorResult(const VisitorCallbackData &, void *,
181+
llvm::Instruction *);
163182
using PayloadProjectionCallback = void *(void *);
164183

165184
struct VisitorHandler {
@@ -290,8 +309,8 @@ class VisitorBase {
290309

291310
void call(HandlerRange handlers, void *payload,
292311
llvm::Instruction &inst) const;
293-
void call(const VisitorHandler &handler, void *payload,
294-
llvm::Instruction &inst) const;
312+
VisitorResult call(const VisitorHandler &handler, void *payload,
313+
llvm::Instruction &inst) const;
295314

296315
template <typename FilterT>
297316
void visitByDeclarations(void *payload, llvm::Module &module,
@@ -369,34 +388,74 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
369388

370389
Visitor<PayloadT> build() { return VisitorBuilderBase::build(); }
371390

391+
template <typename OpT>
392+
VisitorBuilder &add(VisitorResult (*fn)(PayloadT &, OpT &)) {
393+
addCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
394+
return *this;
395+
}
396+
372397
template <typename OpT> VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) {
373398
addCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
374399
return *this;
375400
}
376401

402+
template <typename... OpTs>
403+
VisitorBuilder &addSet(VisitorResult (*fn)(PayloadT &,
404+
llvm::Instruction &I)) {
405+
addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn);
406+
return *this;
407+
}
408+
377409
template <typename... OpTs>
378410
VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) {
379411
addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn);
380412
return *this;
381413
}
382414

415+
VisitorBuilder &addSet(const OpSet &opSet,
416+
VisitorResult (*fn)(PayloadT &,
417+
llvm::Instruction &I)) {
418+
addSetCase(detail::VisitorKey::opSet(opSet), fn);
419+
return *this;
420+
}
421+
383422
VisitorBuilder &addSet(const OpSet &opSet,
384423
void (*fn)(PayloadT &, llvm::Instruction &I)) {
385424
addSetCase(detail::VisitorKey::opSet(opSet), fn);
386425
return *this;
387426
}
388427

428+
template <typename OpT>
429+
VisitorBuilder &add(VisitorResult (PayloadT::*fn)(OpT &)) {
430+
addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
431+
return *this;
432+
}
433+
389434
template <typename OpT> VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) {
390435
addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn);
391436
return *this;
392437
}
393438

439+
VisitorBuilder &addIntrinsic(unsigned id,
440+
VisitorResult (*fn)(PayloadT &,
441+
llvm::IntrinsicInst &)) {
442+
addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
443+
return *this;
444+
}
445+
394446
VisitorBuilder &addIntrinsic(unsigned id,
395447
void (*fn)(PayloadT &, llvm::IntrinsicInst &)) {
396448
addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
397449
return *this;
398450
}
399451

452+
VisitorBuilder &
453+
addIntrinsic(unsigned id,
454+
VisitorResult (PayloadT::*fn)(llvm::IntrinsicInst &)) {
455+
addMemberFnCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
456+
return *this;
457+
}
458+
400459
VisitorBuilder &addIntrinsic(unsigned id,
401460
void (PayloadT::*fn)(llvm::IntrinsicInst &)) {
402461
addMemberFnCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn);
@@ -433,52 +492,72 @@ class VisitorBuilder : private detail::VisitorBuilderBase {
433492
detail::PayloadProjectionCallback *projection)
434493
: VisitorBuilderBase(parent, projection) {}
435494

436-
template <typename OpT>
437-
void addCase(detail::VisitorKey key, void (*fn)(PayloadT &, OpT &)) {
495+
template <typename OpT, typename ReturnT>
496+
void addCase(detail::VisitorKey key, ReturnT (*fn)(PayloadT &, OpT &)) {
438497
detail::VisitorCallbackData data{};
439498
static_assert(sizeof(fn) <= sizeof(data.data));
440499
memcpy(&data.data, &fn, sizeof(fn));
441-
VisitorBuilderBase::add(key, &VisitorBuilder::forwarder<OpT>, data);
500+
VisitorBuilderBase::add(key, &VisitorBuilder::forwarder<OpT, ReturnT>,
501+
data);
442502
}
443503

504+
template <typename ReturnT>
444505
void addSetCase(detail::VisitorKey key,
445-
void (*fn)(PayloadT &, llvm::Instruction &)) {
506+
ReturnT (*fn)(PayloadT &, llvm::Instruction &)) {
446507
detail::VisitorCallbackData data{};
447508
static_assert(sizeof(fn) <= sizeof(data.data));
448509
memcpy(&data.data, &fn, sizeof(fn));
449-
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data);
510+
VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder<ReturnT>, data);
450511
}
451512

452-
template <typename OpT>
453-
void addMemberFnCase(detail::VisitorKey key, void (PayloadT::*fn)(OpT &)) {
513+
template <typename OpT, typename ReturnT>
514+
void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) {
454515
detail::VisitorCallbackData data{};
455516
static_assert(sizeof(fn) <= sizeof(data.data));
456517
memcpy(&data.data, &fn, sizeof(fn));
457-
VisitorBuilderBase::add(key, &VisitorBuilder::memberFnForwarder<OpT>, data);
518+
VisitorBuilderBase::add(
519+
key, &VisitorBuilder::memberFnForwarder<OpT, ReturnT>, data);
458520
}
459521

460-
template <typename OpT>
461-
static void forwarder(const detail::VisitorCallbackData &data, void *payload,
462-
llvm::Instruction *op) {
463-
void (*fn)(PayloadT &, OpT &);
522+
template <typename OpT, typename ReturnT>
523+
static VisitorResult forwarder(const detail::VisitorCallbackData &data,
524+
void *payload, llvm::Instruction *op) {
525+
ReturnT (*fn)(PayloadT &, OpT &);
464526
memcpy(&fn, &data.data, sizeof(fn));
465-
fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op));
527+
if constexpr (std::is_same_v<ReturnT, void>) {
528+
fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op));
529+
return VisitorResult::Continue;
530+
} else {
531+
return fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op));
532+
}
466533
}
467534

468-
static void setForwarder(const detail::VisitorCallbackData &data,
469-
void *payload, llvm::Instruction *op) {
470-
void (*fn)(PayloadT &, llvm::Instruction &);
535+
template <typename ReturnT>
536+
static VisitorResult setForwarder(const detail::VisitorCallbackData &data,
537+
void *payload, llvm::Instruction *op) {
538+
ReturnT (*fn)(PayloadT &, llvm::Instruction &);
471539
memcpy(&fn, &data.data, sizeof(fn));
472-
fn(*static_cast<PayloadT *>(payload), *op);
540+
if constexpr (std::is_same_v<ReturnT, void>) {
541+
fn(*static_cast<PayloadT *>(payload), *op);
542+
return VisitorResult::Continue;
543+
} else {
544+
return fn(*static_cast<PayloadT *>(payload), *op);
545+
}
473546
}
474547

475-
template <typename OpT>
476-
static void memberFnForwarder(const detail::VisitorCallbackData &data,
477-
void *payload, llvm::Instruction *op) {
478-
void (PayloadT::*fn)(OpT &);
548+
template <typename OpT, typename ReturnT>
549+
static VisitorResult
550+
memberFnForwarder(const detail::VisitorCallbackData &data, void *payload,
551+
llvm::Instruction *op) {
552+
ReturnT (PayloadT::*fn)(OpT &);
479553
memcpy(&fn, &data.data, sizeof(fn));
480554
PayloadT *self = static_cast<PayloadT *>(payload);
481-
(self->*fn)(*llvm::cast<OpT>(op));
555+
if constexpr (std::is_same_v<ReturnT, void>) {
556+
(self->*fn)(*llvm::cast<OpT>(op));
557+
return VisitorResult::Continue;
558+
} else {
559+
return (self->*fn)(*llvm::cast<OpT>(op));
560+
}
482561
}
483562
};
484563

lib/Dialect/Visitor.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved.
2+
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -199,12 +199,15 @@ VisitorBase::VisitorBase(VisitorTemplate &&templ)
199199

200200
void VisitorBase::call(HandlerRange handlers, void *payload,
201201
Instruction &inst) const {
202-
for (unsigned idx = handlers.first; idx != handlers.second; ++idx)
203-
call(m_handlers[idx], payload, inst);
202+
for (unsigned idx = handlers.first; idx != handlers.second; ++idx) {
203+
VisitorResult result = call(m_handlers[idx], payload, inst);
204+
if (result == VisitorResult::Stop)
205+
return;
206+
}
204207
}
205208

206-
void VisitorBase::call(const VisitorHandler &handler, void *payload,
207-
Instruction &inst) const {
209+
VisitorResult VisitorBase::call(const VisitorHandler &handler, void *payload,
210+
Instruction &inst) const {
208211
if (handler.projection.isOffset()) {
209212
payload = (char *)payload + handler.projection.getOffset();
210213
} else {
@@ -216,7 +219,7 @@ void VisitorBase::call(const VisitorHandler &handler, void *payload,
216219
}
217220
}
218221

219-
handler.callback(handler.data, payload, &inst);
222+
return handler.callback(handler.data, payload, &inst);
220223
}
221224

222225
void VisitorBase::visit(void *payload, Instruction &inst) const {

test/example/visitor-basic.ll

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
; RUN: llvm-dialects-example -visit %s | FileCheck --check-prefixes=DEFAULT %s
22

33
; DEFAULT: visiting ReadOp: %v = call i32 @xd.read.i32()
4-
; DEFAULT-NEXT: visiting UnaryInstruction: %w = load i32, ptr %p
5-
; DEFAULT-NEXT: visiting UnaryInstruction: %q = load i32, ptr %p1
4+
; DEFAULT-NEXT: visiting UnaryInstruction (pre): %w = load i32, ptr %p
5+
; DEFAULT-NEXT: visiting UnaryInstruction (pre): %q = load i32, ptr %p1
66
; DEFAULT-NEXT: visiting BinaryOperator: %v1 = add i32 %v, %w
77
; DEFAULT-NEXT: visiting umax intrinsic: %v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q)
88
; DEFAULT-NEXT: visiting WriteOp: call void (...) @xd.write(i8 %t)
9+
; DEFAULT-NEXT: visiting SetReadOp: %v.0 = call i1 @xd.set.read.i1()
10+
; DEFAULT-NEXT: visiting SetReadOp: %v.1 = call i32 @xd.set.read.i32()
911
; DEFAULT-NEXT: visiting SetReadOp (set): %v.1 = call i32 @xd.set.read.i32()
12+
; DEFAULT-NEXT: visiting UnaryInstruction (pre): %v.2 = trunc i32 %v.1 to i8
1013
; DEFAULT-NEXT: visiting UnaryInstruction: %v.2 = trunc i32 %v.1 to i8
1114
; DEFAULT-NEXT: visiting SetWriteOp (set): call void (...) @xd.set.write(i8 %v.2)
1215
; DEFAULT-NEXT: visiting WriteVarArgOp: call void (...) @xd.write.vararg(i8 %t, i32 %v2, i32 %q)
@@ -27,6 +30,7 @@ entry:
2730
%v2 = call i32 @llvm.umax.i32(i32 %v1, i32 %q)
2831
%t = call i8 (...) @xd.itrunc.i8(i32 %v2)
2932
call void (...) @xd.write(i8 %t)
33+
%v.0 = call i1 @xd.set.read.i1()
3034
%v.1 = call i32 @xd.set.read.i32()
3135
%v.2 = trunc i32 %v.1 to i8
3236
call void (...) @xd.set.write(i8 %v.2)
@@ -36,6 +40,7 @@ entry:
3640
}
3741

3842
declare i32 @xd.read.i32()
43+
declare i1 @xd.set.read.i1()
3944
declare i32 @xd.set.read.i32()
4045
declare void @xd.write(...)
4146
declare void @xd.set.write(...)

0 commit comments

Comments
 (0)