Skip to content

Commit ff4e840

Browse files
committed
fix and optimise Lock and RLock acquire builtins
- enable test timeout unittest
1 parent 443a4db commit ff4e840

File tree

8 files changed

+198
-135
lines changed

8 files changed

+198
-135
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_thread.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,17 @@ def tearDown(self):
286286
support.threading_cleanup(*self._threads)
287287
support.reap_children()
288288

289+
def assertLess(self, a, b, msg=None):
290+
if not a < b:
291+
standardMsg = '%s not less than %s' % (a, b)
292+
self.fail(self._formatMessage(msg, standardMsg))
293+
294+
def assertGreaterEqual(self, a, b, msg=None):
295+
"""Just like self.assertTrue(a >= b), but with a nicer default message."""
296+
if not a >= b:
297+
standardMsg = '%s not greater than or equal to %s' % (a, b)
298+
self.fail(self._formatMessage(msg, standardMsg))
299+
289300
def assertTimeout(self, actual, expected):
290301
# The waiting and/or time.time() can be imprecise, which
291302
# is why comparing to the expected value would sometimes fail
@@ -402,32 +413,32 @@ def f():
402413
time.sleep(0.4)
403414
self.assertEqual(n, len(threading.enumerate()))
404415

405-
# def test_timeout(self):
406-
# lock = self.locktype()
407-
# # Can't set timeout if not blocking
408-
# self.assertRaises(ValueError, lock.acquire, 0, 1)
409-
# # Invalid timeout values
410-
# self.assertRaises(ValueError, lock.acquire, timeout=-100)
411-
# self.assertRaises(OverflowError, lock.acquire, timeout=1e100)
412-
# self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1)
413-
# # TIMEOUT_MAX is ok
414-
# lock.acquire(timeout=TIMEOUT_MAX)
415-
# lock.release()
416-
# t1 = time.time()
417-
# self.assertTrue(lock.acquire(timeout=5))
418-
# t2 = time.time()
419-
# # Just a sanity test that it didn't actually wait for the timeout.
420-
# self.assertLess(t2 - t1, 5)
421-
# results = []
422-
#
423-
# def f():
424-
# t1 = time.time()
425-
# results.append(lock.acquire(timeout=0.5))
426-
# t2 = time.time()
427-
# results.append(t2 - t1)
428-
# Bunch(f, 1).wait_for_finished()
429-
# self.assertFalse(results[0])
430-
# self.assertTimeout(results[1], 0.5)
416+
def test_timeout(self):
417+
lock = self.locktype()
418+
# Can't set timeout if not blocking
419+
self.assertRaises(ValueError, lock.acquire, 0, 1)
420+
# Invalid timeout values
421+
self.assertRaises(ValueError, lock.acquire, timeout=-100)
422+
self.assertRaises(OverflowError, lock.acquire, timeout=1e100)
423+
self.assertRaises(OverflowError, lock.acquire, timeout=thread.TIMEOUT_MAX + 1)
424+
# TIMEOUT_MAX is ok
425+
lock.acquire(timeout=thread.TIMEOUT_MAX)
426+
lock.release()
427+
t1 = time.time()
428+
self.assertTrue(lock.acquire(timeout=5))
429+
t2 = time.time()
430+
# Just a sanity test that it didn't actually wait for the timeout.
431+
self.assertLess(t2 - t1, 5)
432+
results = []
433+
434+
def f():
435+
t1 = time.time()
436+
results.append(lock.acquire(timeout=0.5))
437+
t2 = time.time()
438+
results.append(t2 - t1)
439+
Bunch(f, 1).wait_for_finished()
440+
self.assertFalse(results[0])
441+
self.assertTimeout(results[1], 0.5)
431442

432443
def test_weakref_exists(self):
433444
lock = self.locktype()

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ThreadModuleBuiltins.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
*/
4141
package com.oracle.graal.python.builtins.modules;
4242

43+
import static com.oracle.graal.python.builtins.objects.thread.AbstractPythonLock.TIMEOUT_MAX;
4344
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
4445

4546
import java.util.List;
@@ -72,6 +73,15 @@ protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFa
7273
return ThreadModuleBuiltinsFactory.getFactories();
7374
}
7475

76+
@Builtin(name = "__truffle_get_timeout_max__", fixedNumOfPositionalArgs = 0)
77+
@GenerateNodeFactory
78+
abstract static class GetTimeoutMaxConstNode extends PythonBuiltinNode {
79+
@Specialization
80+
double getId() {
81+
return TIMEOUT_MAX;
82+
}
83+
}
84+
7585
@Builtin(name = "LockType", fixedNumOfPositionalArgs = 1, constructsClass = PythonBuiltinClassType.PLock)
7686
@GenerateNodeFactory
7787
abstract static class ConstructLockNode extends PythonUnaryBuiltinNode {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/thread/AbstractPythonLock.java

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@
4545

4646
public abstract class AbstractPythonLock extends PythonBuiltinObject {
4747

48+
public static double TIMEOUT_MAX = 2 ^ 31;
49+
public static boolean DEFAULT_BLOCKING = true;
50+
public static double DEFAULT_TIMEOUT = -1.0;
51+
4852
AbstractPythonLock(PythonClass cls) {
4953
super(cls);
5054
}
@@ -58,34 +62,14 @@ private static long getTimeoutInMillis(double timeout) {
5862
return seconds * 1000 + milli;
5963
}
6064

61-
abstract boolean tryToAcquire();
62-
63-
abstract boolean blockUntilAcquire() throws InterruptedException;
65+
protected abstract boolean acquireNonBlocking();
6466

65-
abstract boolean acquireTimeout(long timeout) throws InterruptedException;
67+
protected abstract boolean acquireBlocking();
6668

67-
boolean acquire() {
68-
return acquire(false, -1.0);
69-
}
70-
71-
boolean acquire(boolean blocking) {
72-
return acquire(blocking, -1.0);
73-
}
69+
protected abstract boolean acquireTimeout(long timeout);
7470

75-
boolean acquire(boolean blocking, double timeout) {
76-
if (!blocking) {
77-
return tryToAcquire();
78-
} else {
79-
try {
80-
if (timeout < 0) {
81-
return blockUntilAcquire();
82-
} else {
83-
return acquireTimeout(getTimeoutInMillis(timeout));
84-
}
85-
} catch (InterruptedException e) {
86-
return false;
87-
}
88-
}
71+
protected boolean acquireTimeout(double timeout) {
72+
return acquireTimeout(getTimeoutInMillis(timeout));
8973
}
9074

9175
public abstract void release();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/thread/LockBuiltins.java

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@
4040
*/
4141
package com.oracle.graal.python.builtins.objects.thread;
4242

43+
import static com.oracle.graal.python.builtins.objects.thread.AbstractPythonLock.DEFAULT_BLOCKING;
44+
import static com.oracle.graal.python.builtins.objects.thread.AbstractPythonLock.DEFAULT_TIMEOUT;
45+
import static com.oracle.graal.python.builtins.objects.thread.AbstractPythonLock.TIMEOUT_MAX;
4346
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ENTER__;
4447
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EXIT__;
4548
import static com.oracle.graal.python.nodes.SpecialMethodNames.__REPR__;
49+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.OverflowError;
4650
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
4751

4852
import java.util.List;
@@ -59,11 +63,14 @@
5963
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
6064
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6165
import com.oracle.graal.python.nodes.util.CastToDoubleNode;
66+
import com.oracle.truffle.api.CompilerDirectives;
67+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
6268
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6369
import com.oracle.truffle.api.dsl.Cached;
6470
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6571
import com.oracle.truffle.api.dsl.NodeFactory;
6672
import com.oracle.truffle.api.dsl.Specialization;
73+
import com.oracle.truffle.api.profiles.ConditionProfile;
6774

6875
@CoreFunctions(extendClasses = PythonBuiltinClassType.PLock)
6976
public class LockBuiltins extends PythonBuiltins {
@@ -75,68 +82,80 @@ protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFa
7582
@Builtin(name = "acquire", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 3, keywordArguments = {"blocking", "timeout"})
7683
@GenerateNodeFactory
7784
abstract static class AcquireLockNode extends PythonTernaryBuiltinNode {
78-
@Specialization
79-
@TruffleBoundary
80-
boolean doAcquire(PLock self, @SuppressWarnings("unused") PNone waitFlag, @SuppressWarnings("unused") PNone timeout) {
81-
return self.acquire(true);
82-
}
85+
private @Child CastToDoubleNode castToDoubleNode;
86+
private @Child CastToBooleanNode castToBooleanNode;
87+
private @CompilationFinal ConditionProfile isBlockingProfile = ConditionProfile.createBinaryProfile();
88+
private @CompilationFinal ConditionProfile defaultTimeoutProfile = ConditionProfile.createBinaryProfile();
8389

84-
@Specialization
85-
@TruffleBoundary
86-
boolean doAcquire(PLock self, Object blocking, @SuppressWarnings("unused") PNone timeout,
87-
@Cached("createIfTrueNode()") CastToBooleanNode castToBooleanNode) {
88-
return self.acquire(castToBooleanNode.executeWith(blocking));
90+
private CastToDoubleNode getCastToDoubleNode() {
91+
if (castToDoubleNode == null) {
92+
CompilerDirectives.transferToInterpreterAndInvalidate();
93+
castToDoubleNode = insert(CastToDoubleNode.create());
94+
}
95+
return castToDoubleNode;
8996
}
9097

91-
@Specialization
92-
@TruffleBoundary
93-
boolean doAcquire(PLock self, @SuppressWarnings("unused") PNone waitFlag, Object timeout,
94-
@Cached("create()") CastToDoubleNode castToDoubleNode) {
95-
double timeoutSeconds = castToDoubleNode.execute(timeout);
96-
if (timeoutSeconds < 0) {
97-
throw raise(ValueError, "timeout value must be positive");
98+
private CastToBooleanNode getCastToBooleanNode() {
99+
if (castToBooleanNode == null) {
100+
CompilerDirectives.transferToInterpreterAndInvalidate();
101+
castToBooleanNode = insert(CastToBooleanNode.createIfTrueNode());
98102
}
99-
return self.acquire(true, timeoutSeconds);
103+
return castToBooleanNode;
100104
}
101105

102106
@Specialization
103-
@TruffleBoundary
104-
boolean doAcquire(PLock self, Object blocking, Object timeout,
105-
@Cached("create()") CastToDoubleNode castToDoubleNode,
106-
@Cached("createIfTrueNode()") CastToBooleanNode castToBooleanNode) {
107-
boolean isBlocking = castToBooleanNode.executeWith(blocking);
108-
if (!isBlocking) {
109-
throw raise(ValueError, "can't specify a timeout for a non-blocking call");
107+
boolean doAcquire(PLock self, Object blocking, Object timeout) {
108+
// args setup
109+
boolean isBlocking = (blocking instanceof PNone) ? DEFAULT_BLOCKING : getCastToBooleanNode().executeWith(blocking);
110+
double timeoutSeconds = DEFAULT_TIMEOUT;
111+
if (!(timeout instanceof PNone)) {
112+
if (!isBlocking) {
113+
throw raise(ValueError, "can't specify a timeout for a non-blocking call");
114+
}
115+
116+
timeoutSeconds = getCastToDoubleNode().execute(timeout);
117+
118+
if (timeoutSeconds < 0) {
119+
throw raise(ValueError, "timeout value must be positive");
120+
} else if (timeoutSeconds > TIMEOUT_MAX) {
121+
throw raise(OverflowError, "timeout value is too large");
122+
}
110123
}
111-
double timeoutSeconds = castToDoubleNode.execute(timeout);
112-
if (timeoutSeconds < 0) {
113-
throw raise(ValueError, "timeout value must be positive");
124+
125+
// acquire lock
126+
if (isBlockingProfile.profile(!isBlocking)) {
127+
return self.acquireNonBlocking();
128+
} else {
129+
if (defaultTimeoutProfile.profile(timeoutSeconds == DEFAULT_TIMEOUT)) {
130+
return self.acquireBlocking();
131+
} else {
132+
return self.acquireTimeout(timeoutSeconds);
133+
}
114134
}
115-
return self.acquire(true, timeoutSeconds);
116135
}
117136

118137
public static AcquireLockNode create() {
119138
return AcquireLockNodeFactory.create();
120139
}
121140
}
122141

123-
@Builtin(name = "acquire_lock", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 3, keywordArguments = {"waitflag", "timeout"})
142+
@Builtin(name = "acquire_lock", minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 3, keywordArguments = {"blocking", "timeout"})
124143
@GenerateNodeFactory
125144
abstract static class AcquireLockLockNode extends PythonTernaryBuiltinNode {
126145
@Specialization
127-
Object acquire(PLock self, Object waitFlag, Object timeout,
146+
Object acquire(PLock self, Object blocking, Object timeout,
128147
@Cached("create()") AcquireLockNode acquireLockNode) {
129-
return acquireLockNode.execute(self, waitFlag, timeout);
148+
return acquireLockNode.execute(self, blocking, timeout);
130149
}
131150
}
132151

133-
@Builtin(name = __ENTER__, minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 3, keywordArguments = {"waitflag", "timeout"})
152+
@Builtin(name = __ENTER__, minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 3, keywordArguments = {"blocking", "timeout"})
134153
@GenerateNodeFactory
135154
abstract static class EnterLockNode extends PythonTernaryBuiltinNode {
136155
@Specialization
137-
Object acquire(PLock self, Object waitFlag, Object timeout,
156+
Object acquire(PLock self, Object blocking, Object timeout,
138157
@Cached("create()") AcquireLockNode acquireLockNode) {
139-
return acquireLockNode.execute(self, waitFlag, timeout);
158+
return acquireLockNode.execute(self, blocking, timeout);
140159
}
141160
}
142161

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/thread/PLock.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
import java.util.concurrent.TimeUnit;
4545

4646
import com.oracle.graal.python.builtins.objects.type.PythonClass;
47-
48-
public class PLock extends AbstractPythonLock {
47+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
48+
public final class PLock extends AbstractPythonLock {
4949
private final Semaphore semaphore;
5050

5151
public PLock(PythonClass cls) {
@@ -54,22 +54,34 @@ public PLock(PythonClass cls) {
5454
}
5555

5656
@Override
57-
boolean tryToAcquire() {
57+
@TruffleBoundary
58+
protected boolean acquireNonBlocking() {
5859
return semaphore.tryAcquire();
5960
}
6061

6162
@Override
62-
boolean blockUntilAcquire() throws InterruptedException {
63-
semaphore.acquire();
64-
return true;
63+
@TruffleBoundary
64+
protected boolean acquireBlocking() {
65+
try {
66+
semaphore.acquire();
67+
return true;
68+
} catch (InterruptedException e) {
69+
return false;
70+
}
6571
}
6672

6773
@Override
68-
boolean acquireTimeout(long timeout) throws InterruptedException {
69-
return semaphore.tryAcquire(timeout, TimeUnit.MILLISECONDS);
74+
@TruffleBoundary
75+
protected boolean acquireTimeout(long timeout) {
76+
try {
77+
return semaphore.tryAcquire(timeout, TimeUnit.MILLISECONDS);
78+
} catch (InterruptedException e) {
79+
return false;
80+
}
7081
}
7182

7283
@Override
84+
@TruffleBoundary
7385
public void release() {
7486
semaphore.release();
7587
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/thread/PRLock.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@
4444
import java.util.concurrent.locks.ReentrantLock;
4545

4646
import com.oracle.graal.python.builtins.objects.type.PythonClass;
47+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
4748

48-
public class PRLock extends AbstractPythonLock {
49+
public final class PRLock extends AbstractPythonLock {
4950
private class InternalReentrantLock extends ReentrantLock {
5051
private static final long serialVersionUID = 2531000884985514112L;
5152

@@ -84,22 +85,30 @@ public void releaseAll() {
8485
}
8586

8687
@Override
87-
boolean tryToAcquire() {
88+
@TruffleBoundary
89+
protected boolean acquireNonBlocking() {
8890
return lock.tryLock();
8991
}
9092

9193
@Override
92-
boolean blockUntilAcquire() throws InterruptedException {
94+
@TruffleBoundary
95+
protected boolean acquireBlocking() {
9396
lock.lock();
9497
return true;
9598
}
9699

97100
@Override
98-
boolean acquireTimeout(long timeout) throws InterruptedException {
99-
return lock.tryLock(timeout, TimeUnit.MILLISECONDS);
101+
@TruffleBoundary
102+
protected boolean acquireTimeout(long timeout) {
103+
try {
104+
return lock.tryLock(timeout, TimeUnit.MILLISECONDS);
105+
} catch (InterruptedException e) {
106+
return false;
107+
}
100108
}
101109

102110
@Override
111+
@TruffleBoundary
103112
public void release() {
104113
lock.unlock();
105114
}

0 commit comments

Comments
 (0)