|
1 | 1 | /* |
2 | | - * Copyright (c) 2022 Advanced Micro Devices, Inc. All Rights Reserved. |
| 2 | + * Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved. |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
@@ -89,6 +89,25 @@ struct VisitorPayloadProjection { |
89 | 89 | static constexpr std::size_t offset = offsetof(PayloadT, field); \ |
90 | 90 | }; |
91 | 91 |
|
| 92 | +/// @brief Possible result states of visitor callbacks |
| 93 | +/// |
| 94 | +/// A visitor may have multiple callbacks registered that match on the same |
| 95 | +/// instruction. By default, all matching callbacks are invoked in the order in |
| 96 | +/// which they were registered with the visitor. This may not be appropriate. |
| 97 | +/// A common issue is when the callback erases and replaces the visited |
| 98 | +/// instruction. |
| 99 | +/// |
| 100 | +/// Callbacks may explicitly return a result state to indicate whether further |
| 101 | +/// visits are desired. |
| 102 | +enum class VisitorResult { |
| 103 | + /// Continue with the next callbacks on the same instruction. This is the |
| 104 | + /// default when the callback does not return a value. |
| 105 | + Continue, |
| 106 | + |
| 107 | + /// Skip subsequent callbacks |
| 108 | + Stop, |
| 109 | +}; |
| 110 | + |
92 | 111 | namespace detail { |
93 | 112 |
|
94 | 113 | class VisitorBase; |
@@ -158,8 +177,8 @@ struct VisitorCallbackData : public Foo0, Foo1 { |
158 | 177 | char data[Size]; |
159 | 178 | }; |
160 | 179 |
|
161 | | -using VisitorCallback = void(const VisitorCallbackData &, void *, |
162 | | - llvm::Instruction *); |
| 180 | +using VisitorCallback = VisitorResult(const VisitorCallbackData &, void *, |
| 181 | + llvm::Instruction *); |
163 | 182 | using PayloadProjectionCallback = void *(void *); |
164 | 183 |
|
165 | 184 | struct VisitorHandler { |
@@ -290,8 +309,8 @@ class VisitorBase { |
290 | 309 |
|
291 | 310 | void call(HandlerRange handlers, void *payload, |
292 | 311 | llvm::Instruction &inst) const; |
293 | | - void call(const VisitorHandler &handler, void *payload, |
294 | | - llvm::Instruction &inst) const; |
| 312 | + VisitorResult call(const VisitorHandler &handler, void *payload, |
| 313 | + llvm::Instruction &inst) const; |
295 | 314 |
|
296 | 315 | template <typename FilterT> |
297 | 316 | void visitByDeclarations(void *payload, llvm::Module &module, |
@@ -369,34 +388,74 @@ class VisitorBuilder : private detail::VisitorBuilderBase { |
369 | 388 |
|
370 | 389 | Visitor<PayloadT> build() { return VisitorBuilderBase::build(); } |
371 | 390 |
|
| 391 | + template <typename OpT> |
| 392 | + VisitorBuilder &add(VisitorResult (*fn)(PayloadT &, OpT &)) { |
| 393 | + addCase<OpT>(detail::VisitorKey::op<OpT>(), fn); |
| 394 | + return *this; |
| 395 | + } |
| 396 | + |
372 | 397 | template <typename OpT> VisitorBuilder &add(void (*fn)(PayloadT &, OpT &)) { |
373 | 398 | addCase<OpT>(detail::VisitorKey::op<OpT>(), fn); |
374 | 399 | return *this; |
375 | 400 | } |
376 | 401 |
|
| 402 | + template <typename... OpTs> |
| 403 | + VisitorBuilder &addSet(VisitorResult (*fn)(PayloadT &, |
| 404 | + llvm::Instruction &I)) { |
| 405 | + addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn); |
| 406 | + return *this; |
| 407 | + } |
| 408 | + |
377 | 409 | template <typename... OpTs> |
378 | 410 | VisitorBuilder &addSet(void (*fn)(PayloadT &, llvm::Instruction &I)) { |
379 | 411 | addSetCase(detail::VisitorKey::opSet<OpTs...>(), fn); |
380 | 412 | return *this; |
381 | 413 | } |
382 | 414 |
|
| 415 | + VisitorBuilder &addSet(const OpSet &opSet, |
| 416 | + VisitorResult (*fn)(PayloadT &, |
| 417 | + llvm::Instruction &I)) { |
| 418 | + addSetCase(detail::VisitorKey::opSet(opSet), fn); |
| 419 | + return *this; |
| 420 | + } |
| 421 | + |
383 | 422 | VisitorBuilder &addSet(const OpSet &opSet, |
384 | 423 | void (*fn)(PayloadT &, llvm::Instruction &I)) { |
385 | 424 | addSetCase(detail::VisitorKey::opSet(opSet), fn); |
386 | 425 | return *this; |
387 | 426 | } |
388 | 427 |
|
| 428 | + template <typename OpT> |
| 429 | + VisitorBuilder &add(VisitorResult (PayloadT::*fn)(OpT &)) { |
| 430 | + addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn); |
| 431 | + return *this; |
| 432 | + } |
| 433 | + |
389 | 434 | template <typename OpT> VisitorBuilder &add(void (PayloadT::*fn)(OpT &)) { |
390 | 435 | addMemberFnCase<OpT>(detail::VisitorKey::op<OpT>(), fn); |
391 | 436 | return *this; |
392 | 437 | } |
393 | 438 |
|
| 439 | + VisitorBuilder &addIntrinsic(unsigned id, |
| 440 | + VisitorResult (*fn)(PayloadT &, |
| 441 | + llvm::IntrinsicInst &)) { |
| 442 | + addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn); |
| 443 | + return *this; |
| 444 | + } |
| 445 | + |
394 | 446 | VisitorBuilder &addIntrinsic(unsigned id, |
395 | 447 | void (*fn)(PayloadT &, llvm::IntrinsicInst &)) { |
396 | 448 | addCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn); |
397 | 449 | return *this; |
398 | 450 | } |
399 | 451 |
|
| 452 | + VisitorBuilder & |
| 453 | + addIntrinsic(unsigned id, |
| 454 | + VisitorResult (PayloadT::*fn)(llvm::IntrinsicInst &)) { |
| 455 | + addMemberFnCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn); |
| 456 | + return *this; |
| 457 | + } |
| 458 | + |
400 | 459 | VisitorBuilder &addIntrinsic(unsigned id, |
401 | 460 | void (PayloadT::*fn)(llvm::IntrinsicInst &)) { |
402 | 461 | addMemberFnCase<llvm::IntrinsicInst>(detail::VisitorKey::intrinsic(id), fn); |
@@ -433,52 +492,72 @@ class VisitorBuilder : private detail::VisitorBuilderBase { |
433 | 492 | detail::PayloadProjectionCallback *projection) |
434 | 493 | : VisitorBuilderBase(parent, projection) {} |
435 | 494 |
|
436 | | - template <typename OpT> |
437 | | - void addCase(detail::VisitorKey key, void (*fn)(PayloadT &, OpT &)) { |
| 495 | + template <typename OpT, typename ReturnT> |
| 496 | + void addCase(detail::VisitorKey key, ReturnT (*fn)(PayloadT &, OpT &)) { |
438 | 497 | detail::VisitorCallbackData data{}; |
439 | 498 | static_assert(sizeof(fn) <= sizeof(data.data)); |
440 | 499 | memcpy(&data.data, &fn, sizeof(fn)); |
441 | | - VisitorBuilderBase::add(key, &VisitorBuilder::forwarder<OpT>, data); |
| 500 | + VisitorBuilderBase::add(key, &VisitorBuilder::forwarder<OpT, ReturnT>, |
| 501 | + data); |
442 | 502 | } |
443 | 503 |
|
| 504 | + template <typename ReturnT> |
444 | 505 | void addSetCase(detail::VisitorKey key, |
445 | | - void (*fn)(PayloadT &, llvm::Instruction &)) { |
| 506 | + ReturnT (*fn)(PayloadT &, llvm::Instruction &)) { |
446 | 507 | detail::VisitorCallbackData data{}; |
447 | 508 | static_assert(sizeof(fn) <= sizeof(data.data)); |
448 | 509 | memcpy(&data.data, &fn, sizeof(fn)); |
449 | | - VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder, data); |
| 510 | + VisitorBuilderBase::add(key, &VisitorBuilder::setForwarder<ReturnT>, data); |
450 | 511 | } |
451 | 512 |
|
452 | | - template <typename OpT> |
453 | | - void addMemberFnCase(detail::VisitorKey key, void (PayloadT::*fn)(OpT &)) { |
| 513 | + template <typename OpT, typename ReturnT> |
| 514 | + void addMemberFnCase(detail::VisitorKey key, ReturnT (PayloadT::*fn)(OpT &)) { |
454 | 515 | detail::VisitorCallbackData data{}; |
455 | 516 | static_assert(sizeof(fn) <= sizeof(data.data)); |
456 | 517 | memcpy(&data.data, &fn, sizeof(fn)); |
457 | | - VisitorBuilderBase::add(key, &VisitorBuilder::memberFnForwarder<OpT>, data); |
| 518 | + VisitorBuilderBase::add( |
| 519 | + key, &VisitorBuilder::memberFnForwarder<OpT, ReturnT>, data); |
458 | 520 | } |
459 | 521 |
|
460 | | - template <typename OpT> |
461 | | - static void forwarder(const detail::VisitorCallbackData &data, void *payload, |
462 | | - llvm::Instruction *op) { |
463 | | - void (*fn)(PayloadT &, OpT &); |
| 522 | + template <typename OpT, typename ReturnT> |
| 523 | + static VisitorResult forwarder(const detail::VisitorCallbackData &data, |
| 524 | + void *payload, llvm::Instruction *op) { |
| 525 | + ReturnT (*fn)(PayloadT &, OpT &); |
464 | 526 | memcpy(&fn, &data.data, sizeof(fn)); |
465 | | - fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op)); |
| 527 | + if constexpr (std::is_same_v<ReturnT, void>) { |
| 528 | + fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op)); |
| 529 | + return VisitorResult::Continue; |
| 530 | + } else { |
| 531 | + return fn(*static_cast<PayloadT *>(payload), *llvm::cast<OpT>(op)); |
| 532 | + } |
466 | 533 | } |
467 | 534 |
|
468 | | - static void setForwarder(const detail::VisitorCallbackData &data, |
469 | | - void *payload, llvm::Instruction *op) { |
470 | | - void (*fn)(PayloadT &, llvm::Instruction &); |
| 535 | + template <typename ReturnT> |
| 536 | + static VisitorResult setForwarder(const detail::VisitorCallbackData &data, |
| 537 | + void *payload, llvm::Instruction *op) { |
| 538 | + ReturnT (*fn)(PayloadT &, llvm::Instruction &); |
471 | 539 | memcpy(&fn, &data.data, sizeof(fn)); |
472 | | - fn(*static_cast<PayloadT *>(payload), *op); |
| 540 | + if constexpr (std::is_same_v<ReturnT, void>) { |
| 541 | + fn(*static_cast<PayloadT *>(payload), *op); |
| 542 | + return VisitorResult::Continue; |
| 543 | + } else { |
| 544 | + return fn(*static_cast<PayloadT *>(payload), *op); |
| 545 | + } |
473 | 546 | } |
474 | 547 |
|
475 | | - template <typename OpT> |
476 | | - static void memberFnForwarder(const detail::VisitorCallbackData &data, |
477 | | - void *payload, llvm::Instruction *op) { |
478 | | - void (PayloadT::*fn)(OpT &); |
| 548 | + template <typename OpT, typename ReturnT> |
| 549 | + static VisitorResult |
| 550 | + memberFnForwarder(const detail::VisitorCallbackData &data, void *payload, |
| 551 | + llvm::Instruction *op) { |
| 552 | + ReturnT (PayloadT::*fn)(OpT &); |
479 | 553 | memcpy(&fn, &data.data, sizeof(fn)); |
480 | 554 | PayloadT *self = static_cast<PayloadT *>(payload); |
481 | | - (self->*fn)(*llvm::cast<OpT>(op)); |
| 555 | + if constexpr (std::is_same_v<ReturnT, void>) { |
| 556 | + (self->*fn)(*llvm::cast<OpT>(op)); |
| 557 | + return VisitorResult::Continue; |
| 558 | + } else { |
| 559 | + return (self->*fn)(*llvm::cast<OpT>(op)); |
| 560 | + } |
482 | 561 | } |
483 | 562 | }; |
484 | 563 |
|
|
0 commit comments