diff --git a/csrc/host_ir/allocate_and_deallocate.cpp b/csrc/host_ir/allocate_and_deallocate.cpp index d68045dc660..e0a14abf18b 100644 --- a/csrc/host_ir/allocate_and_deallocate.cpp +++ b/csrc/host_ir/allocate_and_deallocate.cpp @@ -8,10 +8,10 @@ #include "host_ir/allocate_and_deallocate.h" -#include #include #include #include +#include #include #include #include @@ -24,91 +24,93 @@ namespace nvfuser::hir { namespace { -class DominatorTree { +class Node { public: - class Node { - public: - Node(Scope* scope, Scope::Iterator iterator) - : scope_(scope), iterator_(iterator) {} - Node(const Node& other) = delete; - Node(Node&& other) = delete; - Node& operator=(const Node& other) = delete; - Node& operator=(Node&& other) = delete; - - const std::vector& children() const { - return children_; - } + Node(Scope* scope, Scope::Iterator iterator, const Node* parent) + : scope_(scope), iterator_(iterator), parent_(parent) {} + Node(const Node& other) = delete; + Node(Node&& other) = delete; + Node& operator=(const Node& other) = delete; + Node& operator=(Node&& other) = delete; + + const std::vector& children() const { + return children_; + } - void addChild(Node* child) { - children_.push_back(child); - } + void addChild(Node* child) { + children_.push_back(child); + } - Scope* scope() const { - return scope_; - } + Scope* scope() const { + return scope_; + } - Scope::Iterator iterator() const { - return iterator_; - } + Scope::Iterator iterator() const { + return iterator_; + } - Expr* getExpr() const { - return *iterator_; - } + Expr* getExpr() const { + return *iterator_; + } - private: - // Consider putting `scope` and `iterator` into a separate Mutator class. - // They are only needed when the user wants to modify the host IR. - Scope* scope_; - Scope::Iterator iterator_; + const Node* parent() const { + return parent_; + } - std::vector children_; + private: + Scope* scope_; + Scope::Iterator iterator_; + const Node* parent_; + std::vector children_; +}; + +void depthFirstTraverse( + const Node* root, + const std::function& pre_fn, + const std::function& post_fn) { + struct Frame { + const Node* node; + bool processed; }; - explicit DominatorTree(hir::HostIrContainer& hic) : hic_(hic) { - build(hic_.topLevel(), /*parent=*/nullptr); + std::stack stack; + stack.push({root, /*processed=*/false}); + while (!stack.empty()) { + Frame& top = stack.top(); + if (top.processed) { + post_fn(top.node); + stack.pop(); + continue; + } + + pre_fn(top.node); + top.processed = true; + for (const Node* child : top.node->children()) { + stack.push({child, /*processed=*/false}); + } + } +} + +class DominatorTree { + public: + explicit DominatorTree(hir::HostIrContainer& hic) : hic_(&hic) { + build(hic_->topLevel(), /*parent=*/nullptr); } const Node* getRoot() const { - const auto& top_level_exprs = hic_.topLevelExprs(); + const auto& top_level_exprs = hic_->topLevelExprs(); NVF_ERROR(!top_level_exprs.empty()); Expr* root = top_level_exprs.front(); return &nodes_.at(root); } - // `pre_fn` is called before traversing any child of a node. `post_fn` is - // called after traversing all children of a node. - void depthFirstTraverse( - const std::function& pre_fn, - const std::function& post_fn) const { - struct Frame { - const Node* node; - bool processed; - }; - - std::stack stack; - stack.emplace(getRoot(), /*processed=*/false); - while (!stack.empty()) { - Frame& top = stack.top(); - if (top.processed) { - post_fn(top.node); - stack.pop(); - continue; - } - - pre_fn(top.node); - top.processed = true; - for (const Node* child : top.node->children()) { - stack.emplace(child, /*processed=*/false); - } - } - } - private: void build(Scope& scope, Node* parent) { for (auto scope_it = scope.exprs().begin(); scope_it != scope.exprs().end(); ++scope_it) { Expr* e = *scope_it; - auto [node_it, inserted] = nodes_.try_emplace(e, &scope, scope_it); + auto [node_it, inserted] = + nodes_.try_emplace(e, &scope, scope_it, parent); NVF_ERROR(inserted); Node& node = node_it->second; if (parent != nullptr) { @@ -131,7 +133,49 @@ class DominatorTree { } } - hir::HostIrContainer& hic_; + hir::HostIrContainer* hic_; + std::unordered_map nodes_; +}; + +class PostDominatorTree { + public: + explicit PostDominatorTree(hir::HostIrContainer& hic) : hic_(&hic) { + build(hic_->topLevel(), /*parent=*/nullptr); + } + + const Node* getRoot() const { + const auto& top_level_exprs = hic_->topLevelExprs(); + NVF_ERROR(!top_level_exprs.empty()); + Expr* root = top_level_exprs.back(); + return &nodes_.at(root); + } + + private: + void build(Scope& scope, Node* parent) { + auto& exprs = scope.exprs(); + for (auto it = exprs.end(); it != exprs.begin();) { + --it; + Expr* e = *it; + auto [node_it, inserted] = nodes_.try_emplace(e, &scope, it, parent); + NVF_ERROR(inserted); + Node& node = node_it->second; + if (parent != nullptr) { + parent->addChild(&node); + } + + if (auto* loop = dynamic_cast(e)) { + build(loop->body(), &node); + } + if (auto* ite = dynamic_cast(e)) { + build(ite->thenBody(), &node); + build(ite->elseBody(), &node); + } + + parent = &node; + } + } + + hir::HostIrContainer* hic_; std::unordered_map nodes_; }; @@ -157,9 +201,10 @@ void insertAllocations(hir::HostIrContainer& hic) { DominatorTree dom_tree(hic); std::unordered_set defined; - dom_tree.depthFirstTraverse( + depthFirstTraverse( + /*root=*/dom_tree.getRoot(), /*pre_fn=*/ - [&](const DominatorTree::Node* node) { + [&](const Node* node) { Expr* e = node->getExpr(); // If `e`'s output needs preallocation but isn't defined, insert an // allocation right before `e`. @@ -178,7 +223,7 @@ void insertAllocations(hir::HostIrContainer& hic) { } }, /*post_fn=*/ - [&](const DominatorTree::Node* node) { + [&](const Node* node) { Expr* e = node->getExpr(); for (auto* out : ir_utils::filterByType(e->outputs())) { defined.erase(out); @@ -186,9 +231,80 @@ void insertAllocations(hir::HostIrContainer& hic) { }); } +// For each TensorView that is allocated or used as an input, find its +// lowest common ancestor in the Post-dominator Tree — the latest point at which +// it can be deallocated. +class LowestCommonAncestor { + public: + explicit LowestCommonAncestor(const PostDominatorTree& pdt) : pdt_(&pdt) { + computeLcaMap(); + } + + const std::unordered_map& getLcaMap() const { + return lca_; + } + + private: + void computeLcaMap() { + int64_t current_depth = -1; + depthFirstTraverse( + /*root=*/pdt_->getRoot(), + /*pre_fn=*/ + [&](const Node* node) { + current_depth++; + NVF_ERROR(depth_.insert({node, current_depth}).second); + Expr* e = node->getExpr(); + + // Temporary special-case for kir::Allocate. We will switch + // inserting a new `hir::Allocate` in host IR lowering where + // the allocated `tv` will be the expr input. + if (auto* alloc = dynamic_cast(e)) { + auto* tv = alloc->buffer()->as(); + lca_[tv] = findLca(lca_[tv], node); + } + for (auto* tv : ir_utils::filterByType(e->inputs())) { + lca_[tv] = findLca(lca_[tv], node); + } + for (auto* tv : ir_utils::filterByType(e->outputs())) { + lca_[tv] = findLca(lca_[tv], node); + } + }, + /*post_fn=*/ + [&](const Node*) { --current_depth; }); + } + + const Node* findLca(const Node* a, const Node* b) const { + if (a == nullptr) { + return b; + } + if (b == nullptr) { + return a; + } + int64_t depth_a = depth_.at(a); + int64_t depth_b = depth_.at(b); + while (depth_a > depth_b) { + a = a->parent(); + depth_a--; + } + while (depth_b > depth_a) { + b = b->parent(); + depth_b--; + } + while (a != b) { + a = a->parent(); + b = b->parent(); + } + return a; + } + + const PostDominatorTree* pdt_; + std::unordered_map depth_; + std::unordered_map lca_; +}; + void insertDeallocations(hir::HostIrContainer& hic) { const std::list& top_level_exprs = hic.topLevelExprs(); - std::for_each(top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) { + std::ranges::for_each(top_level_exprs, [](Expr* expr) { NVF_ERROR( !expr->isA(), "Expected hostir container to not have deallocate, but found one " @@ -196,40 +312,57 @@ void insertDeallocations(hir::HostIrContainer& hic) { expr); }); - // For each input in every expression in the container, find the position of - // its last use and insert a deallocate directly after, except for fusion - // inputs and outputs. - std::unordered_set last_use_found; - for (auto insertion_point = top_level_exprs.end(); - insertion_point != top_level_exprs.begin();) { - auto prev = std::prev(insertion_point); - Expr* e = *prev; - - // Only tensors need to be allocated. - for (auto* in : ir_utils::filterByType(e->inputs())) { - // Fusion inputs are managed by the caller. - if (in->isFusionInput()) { - continue; - } + PostDominatorTree pdt(hic); + LowestCommonAncestor lcas(pdt); - // Fusion outputs need to be kept alive for the caller. - if (in->isFusionOutput()) { - continue; - } + for (const auto& [tv, lca_node] : lcas.getLcaMap()) { + if (tv->isFusionInput() || tv->isFusionOutput()) { + continue; + } + NVF_ERROR( + lca_node != nullptr, + "Could not find least common ancestor for all uses of ", + tv); + auto* deallocate = IrBuilder::create(tv); + lca_node->scope()->insert(std::next(lca_node->iterator()), deallocate); + } +} - // Skip if `e` is not the last use. - if (!last_use_found.insert(in).second) { - continue; - } +void checkMemoryLeak(hir::HostIrContainer& hic) { + PostDominatorTree pdt(hic); + std::unordered_set allocated; - auto* deallocate = IrBuilder::create(in); - hic.insertExprBefore(insertion_point, deallocate); - } + depthFirstTraverse( + pdt.getRoot(), + /*pre_fn=*/ + [&](const Node* node) { + Expr* e = node->getExpr(); + if (auto* alloc = dynamic_cast(e)) { + allocated.insert(alloc->buffer()->as()); + } + for (auto* tv : ir_utils::filterByType(e->inputs())) { + allocated.insert(tv); + } + for (auto* tv : ir_utils::filterByType(e->outputs())) { + allocated.insert(tv); + } + }, + /*post_fn=*/ + [&](const Node* node) { + Expr* e = node->getExpr(); + if (auto* dealloc = dynamic_cast(e)) { + allocated.erase(dealloc->buffer()); + } + }); - // Don't `--insertion_point;` because we'd like to skip newly inserted - // deallocations. - insertion_point = prev; - } + NVF_ERROR( + std::ranges::all_of( + allocated, + [](TensorView* tv) { + return tv->isFusionInput() || tv->isFusionOutput(); + }), + "Memory leak detected in Host IR. Some TensorViews allocated in IR are " + "not deallocated and not fusion inputs/outputs."); } } // namespace @@ -240,6 +373,8 @@ void AllocateAndDeallocate::runPass(Fusion* fusion) { FusionGuard fg(hic); insertAllocations(*hic); insertDeallocations(*hic); + + checkMemoryLeak(*hic); } } // namespace nvfuser::hir