Skip to content

Commit 6952c1f

Browse files
committed
TEST: Check for competing actions on warning filters
1 parent e7e3668 commit 6952c1f

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

nibabel/tests/test_volumeutils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import itertools
2020
import gzip
2121
import bz2
22+
import threading
23+
import time
2224

2325
import numpy as np
2426

@@ -45,6 +47,7 @@
4547
rec2dict,
4648
_dt_min_max,
4749
_write_data,
50+
_ftype4scaled_finite,
4851
)
4952
from ..openers import Opener, BZ2File
5053
from ..casting import (floor_log2, type_info, OK_FLOATS, shared_range)
@@ -1245,3 +1248,42 @@ def read(self, n_bytes):
12451248
'Expected {0} bytes, got {1} bytes from {2}\n'
12461249
' - could the file be damaged?'.format(
12471250
11390625000000000000, 0, 'object'))
1251+
1252+
1253+
def test__ftype4scaled_finite_warningfilters():
1254+
# This test checks our ability to properly manage the thread-unsafe
1255+
# warnings filter list.
1256+
# 32MiB reliably produces the error on my machine; use 128 for safety
1257+
shape = (1024, 1024, 32)
1258+
tst_arr = np.zeros(shape, dtype=np.float32)
1259+
# Ensure that an overflow will happen
1260+
tst_arr[0, 0, 0] = np.finfo(np.float32).max
1261+
tst_arr[-1, -1, -1] = np.finfo(np.float32).min
1262+
go = threading.Event()
1263+
stop = threading.Event()
1264+
err = []
1265+
class MakeTotalDestroy(threading.Thread):
1266+
def run(self):
1267+
# Restore the warnings filters when we're done testing
1268+
with warnings.catch_warnings():
1269+
go.set()
1270+
while not stop.is_set():
1271+
warnings.filters[:] = []
1272+
time.sleep(0.01)
1273+
class CheckScaling(threading.Thread):
1274+
def run(self):
1275+
go.wait()
1276+
try:
1277+
# Use float16 to buy us two failures
1278+
_ftype4scaled_finite(tst_arr, 2.0, 1.0, default=np.float16)
1279+
except Exception as e:
1280+
err.append(e)
1281+
stop.set()
1282+
thread_a = CheckScaling()
1283+
thread_b = MakeTotalDestroy()
1284+
thread_a.start()
1285+
thread_b.start()
1286+
thread_a.join()
1287+
thread_b.join()
1288+
if err:
1289+
raise err[0]

0 commit comments

Comments
 (0)