Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class Val;
f(ShareMemHandles); \
f(HirAliasSelect); \
f(ShardByStream); \
f(Allocate); \
f(Deallocate); \
f(ForLoop); \
f(SymmetricContiguousView);
Expand Down
3 changes: 2 additions & 1 deletion csrc/host_ir/allocate_and_deallocate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <unordered_set>
#include <vector>

#include "host_ir/ir.h"
#include "ir/builder.h"
#include "ir/utils.h"

Expand Down Expand Up @@ -170,7 +171,7 @@ void insertAllocations(hir::HostIrContainer& hic) {

if (needsOutputPreallocation(e)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
IrBuilder::create<hir::Allocate>(out, out->getMemoryType());
node->scope()->insert(node->iterator(), allocate);
}

Expand Down
22 changes: 22 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,28 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
expr_evaluator_.bind(tv, tensor);
}

void HostIrEvaluator::handle(hir::Allocate* allocate) {
FUSER_PERF_SCOPE("HostIrEvaluator::handle(Allocate)");
auto* tv = allocate->in();

GlobalBufferInfo info =
getBufferInfos(expr_evaluator_, PrimDataType::Int, {tv}).at(0);
c10::Device device =
communicator_ ? communicator_->device() : at::Device("cuda:0");
at::Tensor tensor = at::native::empty_strided_cuda(
info.shape_info.logical_sizes,
info.shape_info.logical_strides,
info.type,
c10::nullopt,
device,
c10::nullopt);

if (allocate->zeroInit()) {
tensor.zero_();
}
expr_evaluator_.bind(tv, tensor);
}

void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) {
auto indexed_id =
hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis());
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(MatmulOp*) override;
void handle(LinearOp*) override;
void handle(kir::Allocate*) override;
void handle(Allocate*) override;
void handle(LoadStoreOp*) override;
void handle(BinaryOp*) override;
void handle(ReductionOp*) override;
Expand Down
34 changes: 34 additions & 0 deletions csrc/host_ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,38 @@ std::string ForLoop::toInlineString(int indent_size) const {
index, iter_domain->start(), iter_domain->stop());
}

Allocate::Allocate(
IrBuilderPasskey passkey,
TensorView* in,
MemoryType memory_type,
bool zero_init)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
NVF_ERROR(in->isA<TensorView>(), "hir::Allocate input must be a TensorView.");

addInput(in);
addDataAttribute(memory_type);
addDataAttribute(zero_init);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Allocate)

std::string Allocate::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << in()->toString() << " = ALLOCATE("
<< "mem_type=" << memoryType() << ", "
<< "zero_init=" << boolLiteral(zeroInit()) << ")"
<< std::endl;
return ss.str();
}

std::string Allocate::toInlineString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << in()->toInlineString() << " = ALLOCATE("
<< "mem_type=" << memoryType() << ", "
<< "zero_init=" << boolLiteral(zeroInit()) << ")";
return ss.str();
}

} // namespace nvfuser::hir
36 changes: 36 additions & 0 deletions csrc/host_ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,42 @@ class LaunchKernel : public Expr {
CompiledKernel* compiled_kernel_ = nullptr;
};

class Allocate : public Expr {
public:
using Expr::Expr;

explicit Allocate(
IrBuilderPasskey passkey,
TensorView* in,
MemoryType memory_type,
bool zero_init = false);

Allocate(const Allocate& other) = delete;
Allocate& operator=(const Allocate& other) = delete;
Allocate(Allocate&& other) = delete;
Allocate& operator=(Allocate&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::Allocate";
}

TensorView* in() const {
return inputs().at(0);
}

MemoryType memoryType() const {
return attribute<MemoryType>(0);
}

bool zeroInit() const {
return attribute<bool>(1);
}
};

class Deallocate : public Expr {
public:
using Expr::Expr;
Expand Down
6 changes: 3 additions & 3 deletions csrc/host_ir/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void lowerSegment(

// Allocate the recv buffers of communications
auto* allocate =
IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
IrBuilder::create<hir::Allocate>(out, out->getMemoryType());
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
Expand Down Expand Up @@ -311,7 +311,7 @@ void lowerSegment(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
auto* allocate =
IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
IrBuilder::create<hir::Allocate>(out, out->getMemoryType());
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
// Loop is stream parallelized but allocation is not. Therefore,
Expand Down Expand Up @@ -348,7 +348,7 @@ void lowerSegment(
alias);

auto* allocate =
IrBuilder::create<kir::Allocate>(out_tv, out_tv->getMemoryType());
IrBuilder::create<hir::Allocate>(out_tv, out_tv->getMemoryType());
innermost_scope.pushBack(allocate);
}

Expand Down
Loading