Skip to content

Commit 03694d1

Browse files
committed
fix cherry-pick; DeviceAllocator is not part of 2.8
1 parent 12b6b94 commit 03694d1

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,33 +37,6 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo
3737
allocator_->copy_data(dest, src, count);
3838
}
3939

40-
// From DeviceAllocator
41-
42-
bool initialized() override {
43-
return allocator_->initialized();
44-
}
45-
46-
void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
47-
allocator_->emptyCache(mempool_id);
48-
}
49-
50-
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
51-
HIPStream hip_stream = HIPStream(stream);
52-
recordStream(ptr, hip_stream);
53-
}
54-
55-
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
56-
return allocator_->getDeviceStats(device);
57-
}
58-
59-
void resetAccumulatedStats(c10::DeviceIndex device) override {
60-
allocator_->resetAccumulatedStats(device);
61-
}
62-
63-
void resetPeakStats(c10::DeviceIndex device) override {
64-
allocator_->resetPeakStats(device);
65-
}
66-
6740
// From CUDAAllocator
6841

6942
void* raw_alloc(size_t nbytes) override {
@@ -82,6 +55,10 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo
8255
allocator_->init(device_count);
8356
}
8457

58+
bool initialized() override {
59+
return allocator_->initialized();
60+
}
61+
8562
double getMemoryFraction(c10::DeviceIndex device) override {
8663
return allocator_->getMemoryFraction(device);
8764
}
@@ -90,6 +67,10 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo
9067
allocator_->setMemoryFraction(fraction, device);
9168
}
9269

70+
void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
71+
allocator_->emptyCache(mempool_id);
72+
}
73+
9374
void enable(bool value) override {
9475
allocator_->enable(value);
9576
}
@@ -110,6 +91,18 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo
11091
allocator_->recordStream(ptr, stream);
11192
}
11293

94+
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
95+
return allocator_->getDeviceStats(device);
96+
}
97+
98+
void resetAccumulatedStats(c10::DeviceIndex device) override {
99+
allocator_->resetAccumulatedStats(device);
100+
}
101+
102+
void resetPeakStats(c10::DeviceIndex device) override {
103+
allocator_->resetPeakStats(device);
104+
}
105+
113106
HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
114107
return allocator_->snapshot(mempool_id);
115108
}

0 commit comments

Comments
 (0)