Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 202 additions & 5 deletions aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <c10/core/Allocator.h>
#include <c10/core/DeviceType.h>
#include <c10/hip/HIPCachingAllocator.h>

#include <utility>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
Expand All @@ -10,22 +9,220 @@ namespace c10::hip {
// Takes a valid HIPAllocator (of any sort) and turns it into
// an allocator pretending to be a CUDA allocator. See
// Note [Masquerading as CUDA]
class HIPAllocatorMasqueradingAsCUDA final : public Allocator {
Allocator* allocator_;
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
HIPCachingAllocator::HIPAllocator* allocator_;
public:
explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator)
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
: allocator_(allocator) {}

virtual ~HIPAllocatorMasqueradingAsCUDA() = default;

// From c10::Allocator

DataPtr allocate(size_t size) override {
DataPtr r = allocator_->allocate(size);
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
return r;
}

bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
return allocator_->is_simple_data_ptr(data_ptr);
}

DeleterFnPtr raw_deleter() const override {
return allocator_->raw_deleter();
}

void copy_data(void* dest, const void* src, std::size_t count) const final {
allocator_->copy_data(dest, src, count);
}

// From CUDAAllocator

void* raw_alloc(size_t nbytes) override {
return allocator_->raw_alloc(nbytes);
}

void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
return allocator_->raw_alloc_with_stream(nbytes, stream);
}

void raw_delete(void* ptr) override {
allocator_->raw_delete(ptr);
}

void init(int device_count) override {
allocator_->init(device_count);
}

bool initialized() override {
return allocator_->initialized();
}

double getMemoryFraction(c10::DeviceIndex device) override {
return allocator_->getMemoryFraction(device);
}

void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
allocator_->setMemoryFraction(fraction, device);
}

void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
allocator_->emptyCache(mempool_id);
}

void enable(bool value) override {
allocator_->enable(value);
}

bool isEnabled() const override {
return allocator_->isEnabled();
}

void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
allocator_->cacheInfo(device, largestBlock);
}

void* getBaseAllocation(void* ptr, size_t* size) override {
return allocator_->getBaseAllocation(ptr, size);
}

void recordStream(const DataPtr& ptr, HIPStream stream) override {
allocator_->recordStream(ptr, stream);
}

CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
return allocator_->getDeviceStats(device);
}

void resetAccumulatedStats(c10::DeviceIndex device) override {
allocator_->resetAccumulatedStats(device);
}

void resetPeakStats(c10::DeviceIndex device) override {
allocator_->resetPeakStats(device);
}

HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
return allocator_->snapshot(mempool_id);
}

void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) override {
allocator_->beginAllocateToPool(device, mempool_id, filter);
}

void endAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) override {
allocator_->endAllocateToPool(device, mempool_id);
}

void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->releasePool(device, mempool_id);
}

int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
return allocator_->getPoolUseCount(device, mempool_id);
}

void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPAllocator* allocator = nullptr) override {
allocator_->createOrIncrefPool(device, mempool_id, allocator);
}

void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->setUseOnOOM(device, mempool_id);
}

bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
}

HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
return allocator_->shareIpcHandle(ptr);
}

std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
return allocator_->getIpcDevPtr(handle);
}

bool isHistoryEnabled() override {
return allocator_->isHistoryEnabled();
}

void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) override {
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}

void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) override {
allocator_->recordAnnotation(md);
}

void pushCompileContext(std::string& md) override {
allocator_->pushCompileContext(md);
}

void popCompileContext() override {
allocator_->popCompileContext();
}

void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
allocator_->attachOutOfMemoryObserver(observer);
}

void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
allocator_->attachAllocatorTraceTracker(tracker);
}

void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
allocator_->enablePeerAccess(dev, dev_to_access);
}

hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) override {
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}

std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) override {
return allocator_->getCheckpointState(device, id);
}

HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
auto cpd = allocator_->setCheckpointPoolState(device, pps);
for (auto& ptr : cpd.dataptrs_allocd) {
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
}
return cpd;
}

std::string name() override {
return allocator_->name();
}

};

} // namespace c10::hip
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#include <c10/core/Allocator.h>
#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>

namespace c10 { namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {

Allocator* get() {
HIPCachingAllocator::HIPAllocator* get() {
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
return &allocator;
}
Expand Down
Loading