Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,7 @@ if(BUILD_TEST)
list(APPEND HOSTIR_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_host_ir_evaluator.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_passes.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_stream_lowering.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp
)
Expand Down
289 changes: 195 additions & 94 deletions csrc/host_ir/allocate_and_deallocate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

#include "host_ir/allocate_and_deallocate.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <list>
#include <ranges>
#include <stack>
#include <unordered_map>
#include <unordered_set>
Expand All @@ -24,91 +24,101 @@ namespace nvfuser::hir {

namespace {

class DominatorTree {
class Node {
public:
class Node {
public:
Node(Scope* scope, Scope::Iterator iterator)
: scope_(scope), iterator_(iterator) {}
Node(const Node& other) = delete;
Node(Node&& other) = delete;
Node& operator=(const Node& other) = delete;
Node& operator=(Node&& other) = delete;

const std::vector<Node*>& children() const {
return children_;
}
Node(Scope* scope, Scope::Iterator iterator, const Node* parent)
: scope_(scope), iterator_(iterator), parent_(parent) {}
Node(const Node& other) = delete;
Node(Node&& other) = delete;
Node& operator=(const Node& other) = delete;
Node& operator=(Node&& other) = delete;

const std::vector<Node*>& children() const {
return children_;
}

void addChild(Node* child) {
children_.push_back(child);
}
void addChild(Node* child) {
children_.push_back(child);
}

Scope* scope() const {
return scope_;
}
Scope* scope() const {
return scope_;
}

Scope::Iterator iterator() const {
return iterator_;
}
Scope::Iterator iterator() const {
return iterator_;
}

Expr* getExpr() const {
return *iterator_;
}
Expr* getExpr() const {
return *iterator_;
}

private:
// Consider putting `scope` and `iterator` into a separate Mutator class.
// They are only needed when the user wants to modify the host IR.
Scope* scope_;
Scope::Iterator iterator_;
const Node* parent() const {
return parent_;
}

std::vector<Node*> children_;
private:
Scope* scope_;
Scope::Iterator iterator_;
const Node* parent_;
std::vector<Node*> children_;
};

// `pre_fn` is called before traversing any child of a node. `post_fn` is
// called after traversing all children of a node.
void depthFirstTraverseFromRoot(
const Node* root,
const std::function<void(const Node*)>& pre_fn,
const std::function<void(const Node*)>& post_fn) {
struct Frame {
const Node* node;
bool processed;
};

explicit DominatorTree(hir::HostIrContainer& hic) : hic_(hic) {
build(hic_.topLevel(), /*parent=*/nullptr);
std::stack<Frame> stack;
stack.push({root, /*processed=*/false});
while (!stack.empty()) {
Frame& top = stack.top();
if (top.processed) {
post_fn(top.node);
stack.pop();
continue;
}

pre_fn(top.node);
top.processed = true;
for (const Node* child : top.node->children()) {
stack.push({child, /*processed=*/false});
}
}
}

class DominatorTree {
public:
explicit DominatorTree(hir::HostIrContainer& hic) : hic_(&hic) {
build(hic_->topLevel(), /*parent=*/nullptr);
}

const Node* getRoot() const {
const auto& top_level_exprs = hic_.topLevelExprs();
const auto& top_level_exprs = hic_->topLevelExprs();
NVF_ERROR(!top_level_exprs.empty());
Expr* root = top_level_exprs.front();
return &nodes_.at(root);
}

// `pre_fn` is called before traversing any child of a node. `post_fn` is
// called after traversing all children of a node.
void depthFirstTraverse(
const std::function<void(const Node*)>& pre_fn,
const std::function<void(const Node*)>& post_fn) const {
struct Frame {
const Node* node;
bool processed;
};

std::stack<Frame> stack;
stack.emplace(getRoot(), /*processed=*/false);
while (!stack.empty()) {
Frame& top = stack.top();
if (top.processed) {
post_fn(top.node);
stack.pop();
continue;
}

pre_fn(top.node);
top.processed = true;
for (const Node* child : top.node->children()) {
stack.emplace(child, /*processed=*/false);
}
}
depthFirstTraverseFromRoot(getRoot(), pre_fn, post_fn);
}

private:
void build(Scope& scope, Node* parent) {
for (auto scope_it = scope.exprs().begin(); scope_it != scope.exprs().end();
++scope_it) {
Expr* e = *scope_it;
auto [node_it, inserted] = nodes_.try_emplace(e, &scope, scope_it);
auto [node_it, inserted] =
nodes_.try_emplace(e, &scope, scope_it, parent);
NVF_ERROR(inserted);
Node& node = node_it->second;
if (parent != nullptr) {
Expand All @@ -131,7 +141,60 @@ class DominatorTree {
}
}

hir::HostIrContainer& hic_;
hir::HostIrContainer* hic_;
std::unordered_map<const Expr*, Node> nodes_;
};

class PostDominatorTree {
public:
explicit PostDominatorTree(hir::HostIrContainer& hic) : hic_(&hic) {
build(hic_->topLevel(), /*parent=*/nullptr);
}

const Node* getRoot() const {
const auto& top_level_exprs = hic_->topLevelExprs();
NVF_ERROR(!top_level_exprs.empty());
Expr* root = top_level_exprs.back();
return &nodes_.at(root);
}

const Node* getNode(Expr* expr) const {
auto it = nodes_.find(expr);
return it != nodes_.end() ? &it->second : nullptr;
}

void depthFirstTraverse(
const std::function<void(const Node*)>& pre_fn,
const std::function<void(const Node*)>& post_fn) const {
depthFirstTraverseFromRoot(getRoot(), pre_fn, post_fn);
}

private:
void build(Scope& scope, Node* parent) {
auto& exprs = scope.exprs();
for (auto it = exprs.end(); it != exprs.begin();) {
--it;
Expr* e = *it;
auto [node_it, inserted] = nodes_.try_emplace(e, &scope, it, parent);
NVF_ERROR(inserted);
Node& node = node_it->second;
if (parent != nullptr) {
parent->addChild(&node);
}

if (auto* loop = dynamic_cast<hir::ForLoop*>(e)) {
build(loop->body(), &node);
}
if (auto* ite = dynamic_cast<kir::IfThenElse*>(e)) {
build(ite->thenBody(), &node);
build(ite->elseBody(), &node);
}

parent = &node;
}
}

hir::HostIrContainer* hic_;
std::unordered_map<const Expr*, Node> nodes_;
};

Expand Down Expand Up @@ -159,7 +222,7 @@ void insertAllocations(hir::HostIrContainer& hic) {

dom_tree.depthFirstTraverse(
/*pre_fn=*/
[&](const DominatorTree::Node* node) {
[&](const Node* node) {
Expr* e = node->getExpr();
// If `e`'s output needs preallocation but isn't defined, insert an
// allocation right before `e`.
Expand All @@ -178,57 +241,95 @@ void insertAllocations(hir::HostIrContainer& hic) {
}
},
/*post_fn=*/
[&](const DominatorTree::Node* node) {
[&](const Node* node) {
Expr* e = node->getExpr();
for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) {
defined.erase(out);
}
});
}

// For each TensorView that is allocated or used as an input, find its
// least common ancestor in the Post-dominator Tree — the latest point at which
// it can be deallocated.
std::unordered_map<TensorView*, const Node*> computeLeastCommonAncestor(
const PostDominatorTree& post_dom_tree) {
std::unordered_map<const Node*, int64_t> depth;

auto findLCA = [&](const Node* a, const Node* b) -> const Node* {
if (a == nullptr) {
return b;
}
if (b == nullptr) {
return a;
}
int64_t depth_a = depth.at(a);
int64_t depth_b = depth.at(b);
while (depth_a > depth_b) {
a = a->parent();
depth_a--;
}
while (depth_b > depth_a) {
b = b->parent();
depth_b--;
}
while (a != b) {
a = a->parent();
b = b->parent();
}
return a;
};

std::unordered_map<TensorView*, const Node*> lca;
int64_t current_depth = -1;

post_dom_tree.depthFirstTraverse(
/*pre_fn=*/
[&](const Node* node) {
current_depth++;
depth[node] = current_depth;
Expr* e = node->getExpr();

if (auto* alloc = dynamic_cast<kir::Allocate*>(e)) {
TensorView* tv = alloc->buffer()->as<TensorView>();
lca[tv] = findLCA(lca[tv], node);
}
for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) {
lca[in] = findLCA(lca[in], node);
}
},
/*post_fn=*/
[&](const Node*) { --current_depth; });

return lca;
}

void insertDeallocations(hir::HostIrContainer& hic) {
const std::list<Expr*>& top_level_exprs = hic.topLevelExprs();
std::for_each(top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) {
std::ranges::for_each(top_level_exprs, [](Expr* expr) {
NVF_ERROR(
!expr->isA<hir::Deallocate>(),
"Expected hostir container to not have deallocate, but found one "
"anyways: ",
expr);
});

// For each input in every expression in the container, find the position of
// its last use and insert a deallocate directly after, except for fusion
// inputs and outputs.
std::unordered_set<TensorView*> last_use_found;
for (auto insertion_point = top_level_exprs.end();
insertion_point != top_level_exprs.begin();) {
auto prev = std::prev(insertion_point);
Expr* e = *prev;

// Only tensors need to be allocated.
for (auto* in : ir_utils::filterByType<TensorView>(e->inputs())) {
// Fusion inputs are managed by the caller.
if (in->isFusionInput()) {
continue;
}
PostDominatorTree post_dominator_tree(hic);
const std::unordered_map<TensorView*, const Node*>& lca_map =
computeLeastCommonAncestor(post_dominator_tree);

// Fusion outputs need to be kept alive for the caller.
if (in->isFusionOutput()) {
continue;
}

// Skip if `e` is not the last use.
if (!last_use_found.insert(in).second) {
continue;
}

auto* deallocate = IrBuilder::create<hir::Deallocate>(in);
hic.insertExprBefore(insertion_point, deallocate);
// Insert deallocate at LCA for each tensorview that is not a fusion input or
// output.
for (const auto& [tv, lca_node] : lca_map) {
if (tv->isFusionInput() || tv->isFusionOutput()) {
continue;
}

// Don't `--insertion_point;` because we'd like to skip newly inserted
// deallocations.
insertion_point = prev;
NVF_ERROR(
lca_node != nullptr,
"Could not find least common ancestor for all uses of ",
tv);
auto* deallocate = IrBuilder::create<hir::Deallocate>(tv);
lca_node->scope()->insert(std::next(lca_node->iterator()), deallocate);
}
}

Expand Down
Loading
Loading