
Commit 060370d

tests: Add memoized method tests
1 parent 9e7e0ff commit 060370d
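This commit adds tests for devito.tools.memoized_meth and memoized_generator. The memoized_meth concurrency test below relies on each thread keeping its own cache, so concurrent callers never block on one another and obj.misses ends up equal to the number of threads. As a rough mental model only (assumed names, not devito's actual implementation), a per-thread memoizing method decorator could be sketched as follows:

import threading
from functools import wraps


def memoized_meth_sketch(func):
    # Illustrative stand-in for a per-thread memoized method decorator;
    # the real devito.tools.memoized_meth may differ in detail.
    local = threading.local()

    @wraps(func)
    def wrapper(self, *args):
        # Lazily create one cache per thread
        cache = getattr(local, 'cache', None)
        if cache is None:
            cache = local.cache = {}
        # Hashing the key raises TypeError for unhashable arguments such as
        # lists, which is the behaviour test_unhashable_args checks
        key = (id(self), func.__name__, args)
        if key not in cache:
            cache[key] = func(self, *args)  # cache miss: run the method once
        return cache[key]

    return wrapper

Under such a scheme, each of the num_threads workers in test_memoized_meth_concurrency misses exactly once and the 0.2 s sleeps overlap, which is what the timing assertion in that test measures.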

1 file changed: +156 −1 lines


tests/test_tools.py

Lines changed: 156 additions & 1 deletion
@@ -1,3 +1,5 @@
+from concurrent.futures import ThreadPoolExecutor
+from threading import RLock
 import numpy as np
 import pytest
 from sympy.abc import a, b, c, d, e
@@ -7,7 +9,7 @@
 from devito import Operator, Eq
 from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort,
                           filter_ordered, transitive_closure, UnboundTuple,
-                          CacheInstances)
+                          CacheInstances, memoized_meth, memoized_generator)
 from devito.types.basic import Symbol


@@ -209,3 +211,156 @@ def __init__(self, value: int):
         # Cache should be cleared after Operator construction
         cache_size = Object._instance_cache.cache_info()[-1]
         assert cache_size == 0
+
+
+class TestMemoizedMethods:
+
+    def test_memoized_meth(self):
+        """
+        Tests basic functionality of memoized_meth
+        """
+        class Object:
+            def __init__(self):
+                self.misses = 0
+
+            @memoized_meth
+            def compute(self, x):
+                self.misses += 1
+                return x * 2
+
+        obj = Object()
+        obj.compute(2)
+        obj.compute(4)
+        assert obj.compute(2) == 4
+        assert obj.compute(4) == 8
+        assert obj.misses == 2  # Only two unique calls
+
+    def test_unhashable_args(self):
+        """
+        Tests that memoized_meth raises an error for unhashable arguments.
+        """
+        class Object:
+            def __init__(self):
+                self.misses = 0
+
+            @memoized_meth
+            def compute(self, x: list[int]):
+                self.misses += 1
+                return sum(x)
+
+        obj = Object()
+        with pytest.raises(TypeError):
+            obj.compute([1, 2, 3])
+
+    @pytest.mark.parametrize('num_threads', [5, 11, 17])
+    def test_memoized_meth_concurrency(self, num_threads: int):
+        """
+        Tests concurrent calls to a memoized method
+        """
+        # Each thread should have its own cache; the calls should not block
+        class Object:
+            def __init__(self):
+                self.misses = 0
+                self.lock = RLock()
+
+            @memoized_meth
+            def compute(self, x):
+                # Record the cache miss for this thread
+                with self.lock:
+                    self.misses += 1
+
+                # Simulate some computation
+                time.sleep(0.2)
+                return x * 2
+
+        obj = Object()
+        def worker(x: int) -> int:
+            a = obj.compute(x)
+            b = obj.compute(x)
+            assert a == b
+            return a
+
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            stime = time.perf_counter()
+            futures = [executor.submit(worker, i % 4) for i in range(num_threads)]
+            results = [f.result() for f in futures]
+            etime = time.perf_counter()
+
+        assert len(set(results)) == 4  # Should have gotten four unique results
+        assert obj.misses == num_threads  # Each thread should have missed once
+
+        # Ensure that the total time is approximately 0.2 seconds (one miss per thread)
+        expected = 0.2
+        assert abs(etime - stime - expected) < 0.1 * expected
+
+    def test_memoized_generator(self):
+        """
+        Tests basic functionality of memoized_generator
+        """
+        class Object:
+            def __init__(self):
+                self.misses = 0
+
+            @memoized_generator
+            def compute(self, x):
+                self.misses += 1
+                yield x * 2
+                yield x * 3
+
+        obj = Object()
+        list(obj.compute(2))
+        assert tuple(obj.compute(2)) == (4, 6)
+        assert obj.misses == 1  # Only one unique call
+
+    @pytest.mark.parametrize('num_threads', [5, 11, 17])
+    def test_memoized_generator_concurrency(self, num_threads: int):
+        """
+        Tests concurrent calls to a memoized generator
+        """
+        class Object:
+            def __init__(self):
+                self.misses = 0
+                self.lock = RLock()
+
+            @memoized_generator
+            def compute(self, x):
+                with self.lock:
+                    self.misses += 1
+
+                time.sleep(0.25)
+                yield x * 2
+
+                time.sleep(0.25)
+                yield x * 3
+
+        # With memoized_generator, the initial construction should block but iteration
+        # should be concurrent and reuse the same iterator.
+
+        obj = Object()
+        def worker(x: int) -> list[int]:
+            return list(obj.compute(x))
+
+        # If one thread consumes the generator, subsequent iteration shouldn't block
+        # First we iterate concurrently; all but one thread should block to wait for
+        # the producing thread, so all will take ~0.5 seconds
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            stime = time.perf_counter()
+            futures = [executor.submit(worker, i % 4) for i in range(num_threads)]
+            results = [f.result() for f in futures]
+            etime = time.perf_counter()
+
+        expected = 0.5
+        assert abs(etime - stime - expected) < 0.1 * expected
+        assert set(tuple(r) for r in results) == {(0, 0), (2, 3), (4, 6), (6, 9)}
+        assert obj.misses == 4  # One miss per unique call
+
+        # Now iterating the same calls should use buffered generators from the cache
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            stime = time.perf_counter()
+            futures = [executor.submit(worker, i % 4) for i in range(num_threads)]
+            results = [f.result() for f in futures]
+            etime = time.perf_counter()
+
+        assert etime - stime < 0.1  # Should take epsilon time
+        assert set(tuple(r) for r in results) == {(0, 0), (2, 3), (4, 6), (6, 9)}
+        assert obj.misses == 4  # No new misses; all calls reused cached generators
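For memoized_generator, the tests above check that the generator body runs once per unique argument set, that first-time concurrent consumers wait for the producing thread, and that later calls replay buffered values almost instantly. A heavily simplified sketch of that buffer-and-replay idea (assumed names, not devito's actual code, and without the real decorator's coordination of concurrent consumers) could look like:

import threading
from functools import wraps
from itertools import tee


def memoized_generator_sketch(func):
    # Illustrative stand-in: run the wrapped generator once per
    # (instance, args) and hand each caller a replay of the buffered values.
    cache = {}
    lock = threading.RLock()

    @wraps(func)
    def wrapper(self, *args):
        key = (id(self), args)
        with lock:
            if key not in cache:
                cache[key] = func(self, *args)  # miss: build the generator once
            # Split the cached iterator: keep one branch in the cache and
            # give the other to the caller as an independent replay
            cache[key], replay = tee(cache[key])
        return replay

    return wrapper

With this sketch, a second tuple(obj.compute(2)) yields the buffered (4, 6) without re-running the body, mirroring test_memoized_generator.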
