diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 214e1ba0b53dd2..ac725494869ffa 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -2339,6 +2339,30 @@ def run_last(): self.assertIn("RuntimeError: can't register atexit after shutdown", err.decode()) +class TestAtomicCounter(unittest.TestCase): + def setUp(self): + self.counter = threading.AtomicCounter() + + def test_initial_value(self): + self.assertEqual(self.counter.get(), 0) + + def test_increment(self): + self.counter.inc() + self.assertEqual(self.counter.get(), 1) + self.counter.inc(5) + self.assertEqual(self.counter.get(), 6) + + def test_decrement(self): + self.counter.dec() + self.assertEqual(self.counter.get(), -1) + self.counter.dec(5) + self.assertEqual(self.counter.get(), -6) + + def test_increment_and_decrement(self): + self.counter.inc(10) + self.counter.dec(3) + self.assertEqual(self.counter.get(), 7) + if __name__ == "__main__": unittest.main() diff --git a/Lib/threading.py b/Lib/threading.py index 78e591124278fc..edae1459a117b8 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -29,7 +29,7 @@ 'Barrier', 'BrokenBarrierError', 'Timer', 'ThreadError', 'setprofile', 'settrace', 'local', 'stack_size', 'excepthook', 'ExceptHookArgs', 'gettrace', 'getprofile', - 'setprofile_all_threads','settrace_all_threads'] + 'setprofile_all_threads','settrace_all_threads', 'AtomicCounter'] # Rename some stuff so "from threading import *" is safe _start_joinable_thread = _thread.start_joinable_thread @@ -857,6 +857,31 @@ def _newname(name_template): _dangling = WeakSet() +class AtomicCounter: + """Threadsafe counter. + + Returns the value after inc/dec operations. + """ + + def __init__(self, initial=0): + self._value = initial + self._lock = Lock() + + def inc(self, n=1): + with self._lock: + self._value += n + return self._value + + def dec(self, n=1): + with self._lock: + self._value -= n + return self._value + + def get(self): + with self._lock: + return self._value + + # Main class for threads class Thread: