Skip to content

Commit 8700d1f

Browse files
committed
init
1 parent 4ac496f commit 8700d1f

File tree

2 files changed

+156
-16
lines changed

2 files changed

+156
-16
lines changed

ultraplot/internals/rcsetup.py

Lines changed: 65 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
import functools
66
import re, matplotlib as mpl
7+
import threading
78
from collections.abc import MutableMapping
89
from numbers import Integral, Real
910

@@ -562,16 +563,42 @@ def _yaml_table(rcdict, comment=True, description=False):
562563

563564
class _RcParams(MutableMapping, dict):
564565
"""
565-
A simple dictionary with locked inputs and validated assignments.
566+
A thread-safe dictionary with validated assignments and thread-local storage used to store the configuration of UltraPlot.
567+
568+
It uses reentrant locks (RLock) to ensure that multiple threads can safely read and write to the configuration without causing data corruption.
569+
570+
Example
571+
-------
572+
>>> with rc_params:
573+
... rc_params['key'] = 'value' # Thread-local change
574+
... # Changes are automatically cleaned up when exiting the context
566575
"""
567576

568577
# NOTE: By omitting __delitem__ in MutableMapping we effectively
569578
# disable mutability. Also disables deleting items with pop().
570579
def __init__(self, source, validate):
571580
self._validate = validate
581+
self._lock = threading.RLock()
582+
self._local = threading.local()
583+
self._local.changes = {} # Initialize thread-local storage
584+
# Register all initial keys in the validation dictionary
585+
for key in source:
586+
if key not in validate:
587+
validate[key] = lambda x: x # Default validator
572588
for key, value in source.items():
573589
self.__setitem__(key, value) # trigger validation
574590

591+
def __enter__(self):
592+
"""Context manager entry - initialize thread-local storage if needed."""
593+
if not hasattr(self._local, "changes"):
594+
self._local.changes = {}
595+
return self
596+
597+
def __exit__(self, exc_type, exc_val, exc_tb):
598+
"""Context manager exit - clean up thread-local storage."""
599+
if hasattr(self._local, "changes"):
600+
del self._local.changes
601+
575602
def __repr__(self):
576603
return RcParams.__repr__(self)
577604

@@ -587,22 +614,33 @@ def __iter__(self):
587614
yield from sorted(dict.__iter__(self))
588615

589616
def __getitem__(self, key):
590-
key, _ = self._check_key(key)
591-
return dict.__getitem__(self, key)
617+
with self._lock:
618+
key, _ = self._check_key(key)
619+
# Check thread-local storage first
620+
if key in self._local.changes:
621+
return self._local.changes[key]
622+
# Check global dictionary (will raise KeyError if not found)
623+
return dict.__getitem__(self, key)
592624

593625
def __setitem__(self, key, value):
594-
key, value = self._check_key(key, value)
595-
if key not in self._validate:
596-
raise KeyError(f"Invalid rc key {key!r}.")
597-
try:
598-
value = self._validate[key](value)
599-
except (ValueError, TypeError) as error:
600-
raise ValueError(f"Key {key}: {error}") from None
601-
if key is not None:
602-
dict.__setitem__(self, key, value)
603-
604-
@staticmethod
605-
def _check_key(key, value=None):
626+
with self._lock:
627+
key, value = self._check_key(key, value)
628+
# Validate the value
629+
try:
630+
value = self._validate[key](value)
631+
except KeyError:
632+
# If key doesn't exist in validation, add it with default validator
633+
self._validate[key] = lambda x: x
634+
# Re-validate with new validator
635+
value = self._validate[key](value)
636+
except (ValueError, TypeError) as error:
637+
raise ValueError(f"Key {key}: {error}") from None
638+
if key is not None:
639+
# Store in both thread-local storage and main dictionary
640+
self._local.changes[key] = value
641+
dict.__setitem__(self, key, value)
642+
643+
def _check_key(self, key, value=None):
606644
# NOTE: If we assigned from the Configurator then the deprecated key will
607645
# still propagate to the same 'children' as the new key.
608646
# NOTE: This also translates values for special cases of renamed keys.
@@ -624,10 +662,21 @@ def _check_key(key, value=None):
624662
f"The rc setting {key!r} was removed in version {version}."
625663
+ (info and " " + info)
626664
)
665+
# Register new keys in the validation dictionary
666+
if key not in self._validate:
667+
self._validate[key] = lambda x: x # Default validator
627668
return key, value
628669

629670
def copy(self):
630-
source = {key: dict.__getitem__(self, key) for key in self}
671+
with self._lock:
672+
# Create a copy that includes both global and thread-local changes
673+
source = {}
674+
# Start with global values
675+
for key in self:
676+
if key not in self._local.changes:
677+
source[key] = dict.__getitem__(self, key)
678+
# Add thread-local changes
679+
source.update(self._local.changes)
631680
return _RcParams(source, self._validate)
632681

633682

ultraplot/tests/test_config.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import ultraplot as uplt, pytest
22
import importlib
3+
import threading
4+
import time
35

46

57
def test_wrong_keyword_reset():
@@ -96,6 +98,95 @@ def test_dev_version_skipped(mock_urlopen, mock_version, mock_print):
9698
mock_print.assert_not_called()
9799

98100

101+
def test_rcparams_thread_safety():
102+
"""
103+
Test that _RcParams is thread-safe when accessed concurrently.
104+
Each thread works with its own unique key to verify proper isolation.
105+
Thread-local changes are properly managed with context manager.
106+
"""
107+
# Create a new _RcParams instance for testing
108+
from ultraplot.internals.rcsetup import _RcParams
109+
110+
# Initialize with base keys
111+
base_keys = {f"base_key_{i}": f"base_value_{i}" for i in range(3)}
112+
rc_params = _RcParams(base_keys, {k: lambda x: x for k in base_keys})
113+
114+
# Number of threads and operations per thread
115+
num_threads = 5
116+
operations_per_thread = 20
117+
118+
# Each thread will work with its own unique key
119+
thread_keys = {}
120+
121+
def worker(thread_id):
122+
"""Thread function that works with its own unique key using context manager."""
123+
# Each thread gets its own unique key
124+
thread_key = f"thread_{thread_id}_key"
125+
thread_keys[thread_id] = thread_key
126+
127+
# Use context manager to ensure proper thread-local cleanup
128+
with rc_params:
129+
# Initialize the key with a base value
130+
rc_params[thread_key] = f"initial_{thread_id}"
131+
132+
# Perform operations
133+
for i in range(operations_per_thread):
134+
try:
135+
# Read the current value
136+
current = rc_params[thread_key]
137+
138+
# Update with new value
139+
new_value = f"thread_{thread_id}_value_{i}"
140+
rc_params[thread_key] = new_value
141+
142+
# Verify the update worked
143+
assert rc_params[thread_key] == new_value
144+
145+
# Also read some base keys to test mixed access
146+
if i % 5 == 0:
147+
base_key = f"base_key_{i % 3}"
148+
base_value = rc_params[base_key]
149+
assert isinstance(base_value, str)
150+
151+
except Exception as e:
152+
raise AssertionError(f"Thread {thread_id} failed: {str(e)}")
153+
154+
# Create and start threads
155+
threads = []
156+
for i in range(num_threads):
157+
t = threading.Thread(target=worker, args=(i,))
158+
threads.append(t)
159+
t.start()
160+
161+
# Wait for all threads to complete
162+
for t in threads:
163+
t.join()
164+
165+
# Verify each thread's key exists and has the expected final value
166+
for thread_id in range(num_threads):
167+
thread_key = thread_keys[thread_id]
168+
assert thread_key in rc_params, f"Thread {thread_id}'s key was lost"
169+
final_value = rc_params[thread_key]
170+
assert final_value == f"thread_{thread_id}_value_{operations_per_thread - 1}"
171+
172+
# Verify base keys are still intact
173+
for key, expected_value in base_keys.items():
174+
assert key in rc_params, f"Base key {key} was lost"
175+
assert rc_params[key] == expected_value, f"Base key {key} value was corrupted"
176+
177+
# Verify that thread-local changes are properly merged
178+
# Create a copy to verify the copy includes thread-local changes
179+
rc_copy = rc_params.copy()
180+
assert len(rc_copy) == len(base_keys) + num_threads, "Copy doesn't include all keys"
181+
182+
# Verify all keys are in the copy
183+
for key in base_keys:
184+
assert key in rc_copy, f"Base key {key} missing from copy"
185+
for thread_id in range(num_threads):
186+
thread_key = thread_keys[thread_id]
187+
assert thread_key in rc_copy, f"Thread {thread_id}'s key missing from copy"
188+
189+
99190
@pytest.mark.parametrize(
100191
"cycle, raises_error",
101192
[

0 commit comments

Comments
 (0)