Skip to content

Commit 5a99c88

Browse files
committed
Extract common tools
1 parent 0ab1db0 commit 5a99c88

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
ALLOCATION_UNIT_SIZE = 512
2+
3+
4+
class _DeviceMemStack:
5+
def __init__(self) -> None:
6+
self.allocations = []
7+
self.current = 0
8+
self.highwater = 0
9+
10+
def malloc(self, bytes):
11+
self.allocations.append(bytes)
12+
allocated = self._round_up(bytes)
13+
self.current += allocated
14+
self.highwater = max(self.current, self.highwater)
15+
16+
def free(self, bytes):
17+
assert bytes in self.allocations
18+
self.allocations.remove(bytes)
19+
self.current -= self._round_up(bytes)
20+
assert self.current >= 0
21+
22+
def _round_up(self, size):
23+
size = (size + ALLOCATION_UNIT_SIZE - 1) // ALLOCATION_UNIT_SIZE
24+
return size * ALLOCATION_UNIT_SIZE

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ def data_file(test_data_path):
6060
def ensure_clean_memory():
6161
cp.get_default_memory_pool().free_all_blocks()
6262
cp.get_default_pinned_memory_pool().free_all_blocks()
63+
cache = cp.fft.config.get_plan_cache()
64+
cache.clear()
6365
yield None
6466
cp.get_default_memory_pool().free_all_blocks()
6567
cp.get_default_pinned_memory_pool().free_all_blocks()
68+
cache = cp.fft.config.get_plan_cache()
69+
cache.clear()
6670

6771

6872
@pytest.fixture
@@ -135,3 +139,20 @@ def host_detector_x(data_file):
135139
@pytest.fixture
136140
def detector_x(host_detector_x, ensure_clean_memory):
137141
return cp.asarray(host_detector_x)
142+
143+
144+
class MaxMemoryHook(cp.cuda.MemoryHook):
145+
def __init__(self, initial=0):
146+
self.max_mem = initial
147+
self.current = initial
148+
149+
def malloc_postprocess(
150+
self, device_id: int, size: int, mem_size: int, mem_ptr: int, pmem_id: int
151+
):
152+
self.current += mem_size
153+
self.max_mem = max(self.max_mem, self.current)
154+
155+
def free_postprocess(
156+
self, device_id: int, mem_size: int, mem_ptr: int, pmem_id: int
157+
):
158+
self.current -= mem_size

0 commit comments

Comments
 (0)