|
18 | 18 | #include "mlir/IR/Operation.h" |
19 | 19 | #include "mlir/IR/SymbolTable.h" |
20 | 20 | #include "mlir/IR/Value.h" |
| 21 | +#include "mlir/Support/StateStack.h" |
21 | 22 | #include "mlir/Target/LLVMIR/Export.h" |
22 | 23 | #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" |
23 | 24 | #include "mlir/Target/LLVMIR/TypeToLLVM.h" |
@@ -271,80 +272,29 @@ class ModuleTranslation { |
271 | 272 | /// it if it does not exist. |
272 | 273 | llvm::NamedMDNode *getOrInsertNamedModuleMetadata(StringRef name); |
273 | 274 |
|
274 | | - /// Common CRTP base class for ModuleTranslation stack frames. |
275 | | - class StackFrame { |
276 | | - public: |
277 | | - virtual ~StackFrame() = default; |
278 | | - TypeID getTypeID() const { return typeID; } |
279 | | - |
280 | | - protected: |
281 | | - explicit StackFrame(TypeID typeID) : typeID(typeID) {} |
282 | | - |
283 | | - private: |
284 | | - const TypeID typeID; |
285 | | - virtual void anchor(); |
286 | | - }; |
287 | | - |
288 | | - /// Concrete CRTP base class for ModuleTranslation stack frames. When |
289 | | - /// translating operations with regions, users of ModuleTranslation can store |
290 | | - /// state on ModuleTranslation stack before entering the region and inspect |
291 | | - /// it when converting operations nested within that region. Users are |
292 | | - /// expected to derive this class and put any relevant information into fields |
293 | | - /// of the derived class. The usual isa/dyn_cast functionality is available |
294 | | - /// for instances of derived classes. |
295 | | - template <typename Derived> |
296 | | - class StackFrameBase : public StackFrame { |
297 | | - public: |
298 | | - explicit StackFrameBase() : StackFrame(TypeID::get<Derived>()) {} |
299 | | - }; |
300 | | - |
301 | 275 | /// Creates a stack frame of type `T` on ModuleTranslation stack. `T` must |
302 | 276 | /// be derived from `StackFrameBase<T>` and constructible from the provided |
303 | 277 | /// arguments. Doing this before entering the region of the op being |
304 | 278 | /// translated makes the frame available when translating ops within that |
305 | 279 | /// region. |
306 | 280 | template <typename T, typename... Args> |
307 | 281 | void stackPush(Args &&...args) { |
308 | | - static_assert( |
309 | | - std::is_base_of<StackFrame, T>::value, |
310 | | - "can only push instances of StackFrame on ModuleTranslation stack"); |
311 | | - stack.push_back(std::make_unique<T>(std::forward<Args>(args)...)); |
| 282 | + stack.stackPush<T>(std::forward<Args>(args)...); |
312 | 283 | } |
313 | 284 |
|
314 | 285 | /// Pops the last element from the ModuleTranslation stack. |
315 | | - void stackPop() { stack.pop_back(); } |
| 286 | + void stackPop() { stack.stackPop(); } |
316 | 287 |
|
317 | 288 | /// Calls `callback` for every ModuleTranslation stack frame of type `T` |
318 | 289 | /// starting from the top of the stack. |
319 | 290 | template <typename T> |
320 | 291 | WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) { |
321 | | - static_assert(std::is_base_of<StackFrame, T>::value, |
322 | | - "expected T derived from StackFrame"); |
323 | | - if (!callback) |
324 | | - return WalkResult::skip(); |
325 | | - for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) { |
326 | | - if (T *ptr = dyn_cast_or_null<T>(frame.get())) { |
327 | | - WalkResult result = callback(*ptr); |
328 | | - if (result.wasInterrupted()) |
329 | | - return result; |
330 | | - } |
331 | | - } |
332 | | - return WalkResult::advance(); |
| 292 | + return stack.stackWalk(callback); |
333 | 293 | } |
334 | 294 |
|
335 | 295 | /// RAII object calling stackPush/stackPop on construction/destruction. |
336 | 296 | template <typename T> |
337 | | - struct SaveStack { |
338 | | - template <typename... Args> |
339 | | - explicit SaveStack(ModuleTranslation &m, Args &&...args) |
340 | | - : moduleTranslation(m) { |
341 | | - moduleTranslation.stackPush<T>(std::forward<Args>(args)...); |
342 | | - } |
343 | | - ~SaveStack() { moduleTranslation.stackPop(); } |
344 | | - |
345 | | - private: |
346 | | - ModuleTranslation &moduleTranslation; |
347 | | - }; |
| 297 | + using SaveStack = SaveStateStack<T, ModuleTranslation>; |
348 | 298 |
|
349 | 299 | SymbolTableCollection &symbolTable() { return symbolTableCollection; } |
350 | 300 |
|
@@ -468,7 +418,7 @@ class ModuleTranslation { |
468 | 418 |
|
469 | 419 | /// Stack of user-specified state elements, useful when translating operations |
470 | 420 | /// with regions. |
471 | | - SmallVector<std::unique_ptr<StackFrame>> stack; |
| 421 | + StateStack stack; |
472 | 422 |
|
473 | 423 | /// A cache for the symbol tables constructed during symbols lookup. |
474 | 424 | SymbolTableCollection symbolTableCollection; |
@@ -510,14 +460,4 @@ llvm::CallInst *createIntrinsicCall( |
510 | 460 | } // namespace LLVM |
511 | 461 | } // namespace mlir |
512 | 462 |
|
513 | | -namespace llvm { |
514 | | -template <typename T> |
515 | | -struct isa_impl<T, ::mlir::LLVM::ModuleTranslation::StackFrame> { |
516 | | - static inline bool |
517 | | - doit(const ::mlir::LLVM::ModuleTranslation::StackFrame &frame) { |
518 | | - return frame.getTypeID() == ::mlir::TypeID::get<T>(); |
519 | | - } |
520 | | -}; |
521 | | -} // namespace llvm |
522 | | - |
523 | 463 | #endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H |
0 commit comments