Skip to content

Commit c72c54b

Browse files
laramielcopybara-github
authored andcommitted
Add some multi-threading tests to tensorstore python
The new tests pass when built with the python GIL. The new tests fail when built in free-threaded mode and TSAN. PiperOrigin-RevId: 824918753 Change-Id: I4d006d41660fcecc0e35c4c7050ec3214ad0f084
1 parent af2bd4f commit c72c54b

File tree

7 files changed

+450
-2
lines changed

7 files changed

+450
-2
lines changed

python/tensorstore/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ tensorstore_pytest_test(
143143
name = "dim_test",
144144
size = "small",
145145
srcs = ["tests/dim_test.py"],
146+
tags = ["cpu:2"],
146147
deps = [
147148
":conftest",
148149
":tensorstore",
@@ -752,6 +753,7 @@ tensorstore_pytest_test(
752753
name = "spec_test",
753754
size = "small",
754755
srcs = ["tests/spec_test.py"],
756+
tags = ["cpu:2"],
755757
deps = [
756758
":conftest",
757759
":tensorstore",
@@ -774,6 +776,7 @@ tensorstore_pytest_test(
774776
name = "tensorstore_test",
775777
size = "small",
776778
srcs = ["tests/tensorstore_test.py"],
779+
tags = ["cpu:2"],
777780
deps = [
778781
":conftest",
779782
":tensorstore",
@@ -794,8 +797,9 @@ tensorstore_pytest_test(
794797

795798
tensorstore_pytest_test(
796799
name = "future_test",
797-
size = "small",
800+
size = "medium",
798801
srcs = ["tests/future_test.py"],
802+
tags = ["cpu:2"],
799803
deps = [
800804
":conftest",
801805
":tensorstore",
@@ -1100,6 +1104,7 @@ tensorstore_pytest_test(
11001104
name = "kvstore_test",
11011105
size = "small",
11021106
srcs = ["tests/kvstore_test.py"],
1107+
tags = ["cpu:2"],
11031108
deps = [
11041109
":conftest",
11051110
":tensorstore",

python/tensorstore/generate_type_stubs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,13 @@ def transform_init_ast(
122122
#
123123
# These are internal implementation details that aren't useful in the type
124124
# stubs.
125-
excluded_symbols = {"_unpickle", "__reduce__", "__getstate__", "__setstate__"}
125+
excluded_symbols = {
126+
"_unpickle",
127+
"__reduce__",
128+
"__getstate__",
129+
"__setstate__",
130+
"__conditional_annotations__",
131+
}
126132

127133
class _InitPyVisitor(ast.NodeTransformer):
128134

python/tensorstore/tests/dim_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"""Tests for tensorstore.Dim"""
1515

1616
import pickle
17+
import threading
18+
import time
1719

1820
import pytest
1921
import tensorstore as ts
@@ -117,3 +119,48 @@ def test_hull() -> None:
117119
def test_pickle() -> None:
118120
x = ts.Dim(inclusive_min=3, size=10)
119121
assert pickle.loads(pickle.dumps(x)) == x
122+
123+
124+
def test_dim_concurrent() -> None:
125+
"""Tests concurrent access to Dim properties."""
126+
dim = ts.Dim()
127+
128+
stop = threading.Event()
129+
130+
def read_props() -> None:
131+
while not stop.is_set():
132+
_ = dim.inclusive_min
133+
_ = dim.implicit_lower
134+
_ = dim.implicit_upper
135+
_ = dim.label
136+
_ = dim == ts.Dim()
137+
_ = f"{dim}"
138+
_ = repr(dim)
139+
140+
def update_props() -> None:
141+
time.sleep(0.01)
142+
i = 0
143+
while not stop.is_set():
144+
if (i % 2) == 0:
145+
dim.implicit_lower = True
146+
dim.implicit_upper = False
147+
dim.label = "x"
148+
else:
149+
dim.implicit_lower = False
150+
dim.implicit_upper = True
151+
dim.label = ""
152+
i += 1
153+
154+
threads = []
155+
for _ in range(4):
156+
threads.append(threading.Thread(target=read_props))
157+
threads.append(threading.Thread(target=update_props))
158+
159+
for t in threads:
160+
t.start()
161+
162+
time.sleep(0.3)
163+
stop.set()
164+
165+
for t in threads:
166+
t.join()

python/tensorstore/tests/future_test.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import asyncio
1616
import os
1717
import pickle
18+
import random
1819
import signal
1920
import threading
2021
import time
@@ -233,3 +234,147 @@ def callback(f: ts.Future) -> None:
233234
del p
234235
assert f.done()
235236
assert exc is not None
237+
238+
239+
def _get_delay() -> float:
240+
return random.random() * 0.01
241+
242+
243+
def _run_threads(
244+
promise: ts.Promise,
245+
i: int,
246+
stop: threading.Event,
247+
threads: list[threading.Thread],
248+
) -> None:
249+
"""Runs a list of threads concurrently with setting the promise result."""
250+
251+
delay = _get_delay()
252+
for t in threads:
253+
t.start()
254+
255+
time.sleep(delay)
256+
try:
257+
promise.set_result(i)
258+
except: # pylint: disable=bare-except
259+
pass
260+
261+
time.sleep(0.01)
262+
stop.set()
263+
264+
for t in threads:
265+
t.join()
266+
267+
268+
def test_future_concurrent_set_cancel_callback() -> None:
269+
"""Test multi-treaded races between adding callbacks and cancellation."""
270+
271+
def _concurrent_set_cancel_callback(i: int) -> None:
272+
callback_called = threading.Event()
273+
callback_result = [None]
274+
callback_error = [None]
275+
276+
def callback(f: ts.Future) -> None:
277+
nonlocal callback_called
278+
nonlocal callback_result
279+
nonlocal callback_error
280+
try:
281+
callback_result[0] = f.result()
282+
except asyncio.CancelledError as e:
283+
callback_error[0] = e
284+
except Exception as e: # pylint: disable=broad-exception-caught
285+
callback_error[0] = e
286+
finally:
287+
callback_called.set()
288+
289+
def do_add_callback(delay: float, f: ts.Future) -> None:
290+
time.sleep(delay)
291+
f.add_done_callback(callback)
292+
293+
def do_cancel(delay: float, f: ts.Future) -> None:
294+
time.sleep(delay)
295+
f.cancel()
296+
297+
promise, future = ts.Promise.new()
298+
_run_threads(
299+
promise,
300+
i,
301+
threading.Event(), # unused
302+
[
303+
threading.Thread(target=do_cancel, args=(_get_delay(), future)),
304+
threading.Thread(
305+
target=do_add_callback, args=(_get_delay(), future)
306+
),
307+
],
308+
)
309+
310+
assert future.done()
311+
assert callback_called.wait(timeout=5)
312+
313+
if future.cancelled():
314+
assert isinstance(callback_error[0], asyncio.CancelledError)
315+
with pytest.raises(asyncio.CancelledError):
316+
future.result()
317+
else:
318+
# If not cancelled, result must have been set.
319+
assert callback_error[0] is None
320+
assert future.result() == i
321+
assert callback_result[0] == i
322+
323+
for i in range(20):
324+
_concurrent_set_cancel_callback(i)
325+
326+
327+
def test_future_concurrent_ops() -> None:
328+
"""Test multi-treaded races between adding callbacks and setting result."""
329+
330+
def _concurrent_ops() -> None:
331+
events = [threading.Event() for _ in range(8)]
332+
results = {}
333+
lock = threading.Lock()
334+
stop = threading.Event()
335+
336+
def make_callback(
337+
idx: int, event: threading.Event
338+
) -> Callable[[ts.Future], None]:
339+
def callback(f: ts.Future) -> None:
340+
try:
341+
result = f.result()
342+
with lock:
343+
results[idx] = result
344+
finally:
345+
event.set()
346+
347+
return callback
348+
349+
promise, future = ts.Promise.new()
350+
351+
def do_add_callback(
352+
delay: float, callback: Callable[[ts.Future], None]
353+
) -> None:
354+
time.sleep(delay)
355+
future.add_done_callback(callback)
356+
357+
def read_props() -> None:
358+
while not stop.is_set():
359+
_ = future.done()
360+
_ = future.cancelled()
361+
362+
threads = []
363+
for i in range(len(events)):
364+
threads.append(
365+
threading.Thread(
366+
target=do_add_callback,
367+
368+
args=(_get_delay(), make_callback(i, events[i])),
369+
)
370+
)
371+
threads.append(threading.Thread(target=read_props))
372+
373+
_run_threads(promise, 42, stop, threads)
374+
375+
for i in range(len(events)):
376+
assert events[i].wait(timeout=5)
377+
assert results[i] == 42
378+
379+
for _ in range(20):
380+
_concurrent_ops()

python/tensorstore/tests/kvstore_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import pathlib
1818
import pickle
1919
import tempfile
20+
import threading
21+
import time
22+
from typing import Callable
2023

2124
import pytest
2225
import tensorstore as ts
@@ -157,3 +160,78 @@ def test_copy_range_to_ocdbt_memory_bad_path() -> None:
157160
).result()
158161
with pytest.raises(NotImplementedError):
159162
child.experimental_copy_range_to(parent).result()
163+
164+
165+
def _run_threads(
166+
stop: threading.Event,
167+
read_props: Callable[[], None],
168+
update_props: Callable[[], None],
169+
) -> None:
170+
threads = []
171+
for _ in range(4):
172+
threads.append(threading.Thread(target=read_props))
173+
threads.append(threading.Thread(target=update_props))
174+
175+
for t in threads:
176+
t.start()
177+
178+
time.sleep(0.3)
179+
stop.set()
180+
181+
for t in threads:
182+
t.join()
183+
184+
185+
def test_kvstore_spec_concurrent_update_and_read() -> None:
186+
"""Validates that concurrent updates and reads do not crash."""
187+
s = ts.KvStore.Spec('memory://')
188+
189+
stop = threading.Event()
190+
191+
def read_props() -> None:
192+
while not stop.is_set():
193+
_ = s.path
194+
_ = s.url
195+
_ = s.base
196+
_ = s == ts.KvStore.Spec('memory://')
197+
_ = s.to_json()
198+
_ = f'{s}'
199+
_ = repr(s)
200+
201+
def update_props() -> None:
202+
time.sleep(0.01)
203+
i = 0
204+
while not stop.is_set():
205+
if (i % 2) == 0:
206+
s.path = 'abc/'
207+
else:
208+
s.path = 'def/'
209+
i += 1
210+
211+
_run_threads(stop, read_props, update_props)
212+
213+
214+
def test_kvstore_keyrange_concurrent() -> None:
215+
"""Tests concurrent access to KvStore.KeyRange properties."""
216+
kr = ts.KvStore.KeyRange('a', 'z')
217+
218+
stop = threading.Event()
219+
220+
def read_props() -> None:
221+
while not stop.is_set():
222+
_ = kr.inclusive_min
223+
_ = kr.exclusive_max
224+
_ = kr.empty
225+
_ = kr == ts.KvStore.KeyRange('a', 'z')
226+
_ = f'{kr}'
227+
_ = repr(kr)
228+
229+
def update_props() -> None:
230+
time.sleep(0.01)
231+
while not stop.is_set():
232+
kr.inclusive_min = 'b'
233+
kr.exclusive_max = 'y'
234+
kr.inclusive_min = 'a'
235+
kr.exclusive_max = 'z'
236+
237+
_run_threads(stop, read_props, update_props)

0 commit comments

Comments
 (0)