Skip to content

Commit ca9f4ba

Browse files
Merge pull request #25 from GlenWalker/thread_races
Fix race conditions for attribute creation
2 parents 312f892 + 6031827 commit ca9f4ba

File tree

2 files changed

+151
-5
lines changed

2 files changed

+151
-5
lines changed

src/apipkg/__init__.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
import sys
1010
from types import ModuleType
1111

12+
# Prior to Python 3.7 threading support was optional
13+
try:
14+
import threading
15+
except ImportError:
16+
threading = None
17+
else:
18+
import functools
19+
1220
from .version import version as __version__ # NOQA:F401
1321

1422

@@ -83,6 +91,22 @@ def importobj(modpath, attrname):
8391
return retval
8492

8593

94+
def _synchronized(wrapped_function):
95+
"""Decorator to synchronise __getattr__ calls."""
96+
if threading is None:
97+
return wrapped_function
98+
99+
# Lock shared between all instances of ApiModule to avoid possible deadlocks
100+
lock = threading.RLock()
101+
102+
@functools.wraps(wrapped_function)
103+
def synchronized_wrapper_function(*args, **kwargs):
104+
with lock:
105+
return wrapped_function(*args, **kwargs)
106+
107+
return synchronized_wrapper_function
108+
109+
86110
class ApiModule(ModuleType):
87111
"""the magical lazy-loading module standing"""
88112

@@ -105,7 +129,6 @@ def __init__(self, name, importspec, implprefix=None, attr=None):
105129
self.__implprefix__ = implprefix or name
106130
if attr:
107131
for name, val in attr.items():
108-
# print "setting", self.__name__, name, val
109132
setattr(self, name, val)
110133
for name, importspec in importspec.items():
111134
if isinstance(importspec, dict):
@@ -139,19 +162,32 @@ def __repr__(self):
139162
return "<ApiModule {!r} {}>".format(self.__name__, " ".join(repr_list))
140163
return "<ApiModule {!r}>".format(self.__name__)
141164

142-
def __makeattr(self, name):
165+
@_synchronized
166+
def __makeattr(self, name, isgetattr=False):
143167
"""lazily compute value for name or raise AttributeError if unknown."""
144-
# print "makeattr", self.__name__, name
145168
target = None
146169
if "__onfirstaccess__" in self.__map__:
147170
target = self.__map__.pop("__onfirstaccess__")
148171
importobj(*target)()
149172
try:
150173
modpath, attrname = self.__map__[name]
151174
except KeyError:
175+
# __getattr__ is called when the attribute does not exist, but it may have
176+
# been set by the onfirstaccess call above. Infinite recursion is not
177+
# possible as __onfirstaccess__ is removed before the call (unless the call
178+
# adds __onfirstaccess__ to __map__ explicitly, which is not our problem)
152179
if target is not None and name != "__onfirstaccess__":
153-
# retry, onfirstaccess might have set attrs
154180
return getattr(self, name)
181+
# Attribute may also have been set during a concurrent call to __getattr__
182+
# which executed after this call was already waiting on the lock. Check
183+
# for a recently set attribute while avoiding infinite recursion:
184+
# * Don't call __getattribute__ if __makeattr was called from a data
185+
# descriptor such as the __doc__ or __dict__ properties, since data
186+
# descriptors are called as part of object.__getattribute__
187+
# * Only call __getattribute__ if there is a possibility something has set
188+
# the attribute we're looking for since __getattr__ was called
189+
if threading is not None and isgetattr:
190+
return super(ApiModule, self).__getattribute__(name)
155191
raise AttributeError(name)
156192
else:
157193
result = importobj(modpath, attrname)
@@ -162,7 +198,8 @@ def __makeattr(self, name):
162198
pass # in a recursive-import situation a double-del can happen
163199
return result
164200

165-
__getattr__ = __makeattr
201+
def __getattr__(self, name):
202+
return self.__makeattr(name, isgetattr=True)
166203

167204
@property
168205
def __dict__(self):

test_apipkg.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import textwrap
55
import types
66

7+
try:
8+
import threading
9+
except ImportError:
10+
pass
11+
712
import pytest
813

914
import apipkg
@@ -456,6 +461,110 @@ def init():
456461
assert "__onfirstaccess__" not in vars(mod)
457462

458463

464+
@pytest.mark.skipif("threading" not in sys.modules, reason="requires thread support")
465+
def test_onfirstaccess_race(tmpdir, monkeypatch):
466+
pkgdir = tmpdir.mkdir("firstaccessrace")
467+
pkgdir.join("__init__.py").write(
468+
textwrap.dedent(
469+
"""
470+
import apipkg
471+
apipkg.initpkg(__name__, exportdefs={
472+
'__onfirstaccess__': '.submod:init',
473+
'l': '.submod:l',
474+
},
475+
)
476+
"""
477+
)
478+
)
479+
pkgdir.join("submod.py").write(
480+
textwrap.dedent(
481+
"""
482+
import time
483+
l = []
484+
def init():
485+
time.sleep(0.1)
486+
l.append(1)
487+
"""
488+
)
489+
)
490+
monkeypatch.syspath_prepend(tmpdir)
491+
import firstaccessrace
492+
493+
assert isinstance(firstaccessrace, apipkg.ApiModule)
494+
495+
class TestThread(threading.Thread):
496+
def __init__(self, event_start):
497+
super(TestThread, self).__init__()
498+
self.event_start = event_start
499+
self.lenl = None
500+
501+
def run(self):
502+
self.event_start.wait()
503+
self.lenl = len(firstaccessrace.l)
504+
505+
event_start = threading.Event()
506+
threads = [TestThread(event_start) for _ in range(8)]
507+
for thread in threads:
508+
thread.start()
509+
event_start.set()
510+
for thread in threads:
511+
thread.join()
512+
assert len(firstaccessrace.l) == 1
513+
for thread in threads:
514+
assert thread.lenl == 1
515+
assert "__onfirstaccess__" not in firstaccessrace.__all__
516+
517+
518+
@pytest.mark.skipif("threading" not in sys.modules, reason="requires thread support")
519+
def test_attribute_race(tmpdir, monkeypatch):
520+
pkgdir = tmpdir.mkdir("attributerace")
521+
pkgdir.join("__init__.py").write(
522+
textwrap.dedent(
523+
"""
524+
import apipkg
525+
apipkg.initpkg(__name__, exportdefs={
526+
'attr': '.submod:attr',
527+
},
528+
)
529+
"""
530+
)
531+
)
532+
pkgdir.join("submod.py").write(
533+
textwrap.dedent(
534+
"""
535+
import time
536+
time.sleep(0.1)
537+
attr = 42
538+
"""
539+
)
540+
)
541+
monkeypatch.syspath_prepend(tmpdir)
542+
import attributerace
543+
544+
assert isinstance(attributerace, apipkg.ApiModule)
545+
546+
class TestThread(threading.Thread):
547+
def __init__(self, event_start):
548+
super(TestThread, self).__init__()
549+
self.event_start = event_start
550+
self.attr = None
551+
552+
def run(self):
553+
self.event_start.wait()
554+
self.attr = attributerace.attr
555+
556+
event_start = threading.Event()
557+
threads = [TestThread(event_start) for _ in range(8)]
558+
for thread in threads:
559+
thread.start()
560+
event_start.set()
561+
for thread in threads:
562+
thread.join()
563+
assert attributerace.attr == 42
564+
for thread in threads:
565+
assert thread.attr == 42
566+
567+
459568
def test_bpython_getattr_override(tmpdir, monkeypatch):
460569
def patchgetattr(self, name):
461570
raise AttributeError(name)

0 commit comments

Comments
 (0)