Skip to content

Commit b914212

Browse files
committed
[GR-46628] Reuse weakref if no callback is specified
PullRequest: graalpython/2834
2 parents d993fae + 27c9571 commit b914212

File tree

3 files changed

+255
-5
lines changed

3 files changed

+255
-5
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_weakref.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -107,3 +107,19 @@ def do_hash(item):
107107
assert False, "could compute hash for r3 but should have failed"
108108

109109
assert r1_hash == do_hash(r1)
110+
111+
def test_weakref_reuse():
112+
from weakref import ref
113+
class A:
114+
pass
115+
a = A()
116+
assert ref(a) == ref(a)
117+
118+
def test_weakref_object_type_support():
119+
from weakref import ref
120+
try:
121+
ref(42)
122+
except TypeError as e:
123+
pass
124+
else:
125+
assert False, "should throw TypeError for unsupported objects"

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/PythonBuiltinClassType.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
import java.util.HashSet;
100100

101101
import com.oracle.graal.python.PythonLanguage;
102+
import com.oracle.graal.python.builtins.modules.WeakRefModuleBuiltins;
102103
import com.oracle.graal.python.builtins.objects.function.BuiltinMethodDescriptor;
103104
import com.oracle.graal.python.builtins.objects.type.SpecialMethodSlot;
104105
import com.oracle.graal.python.runtime.PythonContext;
@@ -502,6 +503,7 @@ private static class Flags {
502503
// initialized in static constructor
503504
@CompilationFinal private PythonBuiltinClassType type;
504505
@CompilationFinal private PythonBuiltinClassType base;
506+
@CompilationFinal private int weaklistoffset;
505507

506508
/**
507509
* @see #redefinesSlot(SpecialMethodSlot)
@@ -543,6 +545,7 @@ private static class Flags {
543545
this.isBuiltinWithDict = flags.isBuiltinWithDict;
544546
this.isException = flags == Flags.EXCEPTION;
545547
this.methodsFlags = methodsFlags;
548+
this.weaklistoffset = -1;
546549
}
547550

548551
PythonBuiltinClassType(String name, String module) {
@@ -613,6 +616,10 @@ public long getMethodsFlags() {
613616
return methodsFlags;
614617
}
615618

619+
public int getWeaklistoffset() {
620+
return weaklistoffset;
621+
}
622+
616623
/**
617624
* Returns {@code true} if this method slot is redefined in Python code during initialization.
618625
* Values of such slots cannot be cached in {@link #specialMethodSlots}, because they are not
@@ -863,6 +870,8 @@ public final Shape getInstanceShape(PythonLanguage lang) {
863870
if (type.type == null && type.base != null) {
864871
type.type = type.base.type;
865872
}
873+
874+
type.weaklistoffset = WeakRefModuleBuiltins.getBuiltinTypeWeaklistoffset(type);
866875
}
867876

868877
// Finally, we set all remaining types to PythonClass.

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/WeakRefModuleBuiltins.java

Lines changed: 229 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import static com.oracle.graal.python.nodes.BuiltinNames.J__WEAKREF;
4444
import static com.oracle.graal.python.nodes.BuiltinNames.T__WEAKREF;
4545
import static com.oracle.graal.python.nodes.StringLiterals.T_REF;
46+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
4647

4748
import java.lang.ref.Reference;
4849
import java.lang.ref.ReferenceQueue;
@@ -70,6 +71,7 @@
7071
import com.oracle.graal.python.nodes.PGuards;
7172
import com.oracle.graal.python.nodes.WriteUnraisableNode;
7273
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
74+
import com.oracle.graal.python.nodes.attributes.WriteAttributeToDynamicObjectNode;
7375
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
7476
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
7577
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
@@ -79,7 +81,6 @@
7981
import com.oracle.graal.python.runtime.PythonContext;
8082
import com.oracle.graal.python.runtime.PythonOptions;
8183
import com.oracle.graal.python.runtime.exception.PException;
82-
import com.oracle.graal.python.runtime.exception.PythonErrorType;
8384
import com.oracle.truffle.api.CompilerDirectives;
8485
import com.oracle.truffle.api.dsl.Bind;
8586
import com.oracle.truffle.api.dsl.Cached;
@@ -95,11 +96,212 @@ public class WeakRefModuleBuiltins extends PythonBuiltins {
9596
private static final HiddenKey weakRefQueueKey = new HiddenKey("weakRefQueue");
9697
private final ReferenceQueue<Object> weakRefQueue = new ReferenceQueue<>();
9798

99+
// This GraalPy specific as CPython is storing weakref list within a PyObject (obj +
100+
// tp_weaklistoffset)
101+
public static final HiddenKey __WEAKLIST__ = new HiddenKey("__weaklist__");
102+
98103
@Override
99104
protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFactories() {
100105
return WeakRefModuleBuiltinsFactory.getFactories();
101106
}
102107

108+
public static int getBuiltinTypeWeaklistoffset(PythonBuiltinClassType cls) {
109+
// @formatter:off
110+
return switch (cls) {
111+
case PythonObject, // object
112+
PInt, // int
113+
Boolean, // bool
114+
PByteArray, // bytearray
115+
PBytes, // bytes
116+
PList, // list
117+
PNone, // NoneType
118+
PNotImplemented, // NotImplementedType
119+
PTraceback, // traceback
120+
Super, // super
121+
PRange, // range
122+
PDict, // dict
123+
PDictKeysView, // dict_keys
124+
PDictValuesView, // dict_values
125+
PDictItemsView, // dict_items
126+
PDictReverseKeyIterator, // dict_reversekeyiterator
127+
PDictReverseValueIterator, // dict_reversevalueiterator
128+
PDictReverseItemIterator, // dict_reverseitemiterator
129+
PString, // str
130+
PSlice, // slice
131+
PStaticmethod, // staticmethod
132+
PComplex, // complex
133+
PFloat, // float
134+
PProperty, // property
135+
PTuple, // tuple
136+
PEnumerate, // enumerate
137+
PReverseIterator, // reversed
138+
PFrame, // frame
139+
PMappingproxy, // mappingproxy
140+
GetSetDescriptor, // getset_descriptor
141+
WrapperDescriptor, // wrapper_descriptor
142+
MethodWrapper, // method-wrapper
143+
PEllipsis, // ellipsis
144+
MemberDescriptor, // member_descriptor
145+
PSimpleNamespace, // types.SimpleNamespace
146+
Capsule, // PyCapsule
147+
PCell, // cell
148+
PInstancemethod, // instancemethod
149+
PBuiltinClassMethod, // classmethod_descriptor
150+
PBuiltinFunction, // method_descriptor
151+
PSentinelIterator, // callable_iterator
152+
PIterator, // iterator
153+
PCoroutineWrapper, // coroutine_wrapper
154+
PEncodingMap, // EncodingMap
155+
PIntInfo, // sys.int_info
156+
PBaseException, // BaseException
157+
Exception, // Exception
158+
TypeError, // TypeError
159+
StopAsyncIteration, // StopAsyncIteration
160+
StopIteration, // StopIteration
161+
GeneratorExit, // GeneratorExit
162+
SystemExit, // SystemExit
163+
KeyboardInterrupt, // KeyboardInterrupt
164+
ImportError, // ImportError
165+
ModuleNotFoundError, // ModuleNotFoundError
166+
OSError, // OSError
167+
EOFError, // EOFError
168+
RuntimeError, // RuntimeError
169+
RecursionError, // RecursionError
170+
NotImplementedError, // NotImplementedError
171+
NameError, // NameError
172+
UnboundLocalError, // UnboundLocalError
173+
AttributeError, // AttributeError
174+
SyntaxError, // SyntaxError
175+
IndentationError, // IndentationError
176+
TabError, // TabError
177+
LookupError, // LookupError
178+
IndexError, // IndexError
179+
KeyError, // KeyError
180+
ValueError, // ValueError
181+
UnicodeError, // UnicodeError
182+
UnicodeEncodeError, // UnicodeEncodeError
183+
UnicodeDecodeError, // UnicodeDecodeError
184+
UnicodeTranslateError, // UnicodeTranslateError
185+
AssertionError, // AssertionError
186+
ArithmeticError, // ArithmeticError
187+
FloatingPointError, // FloatingPointError
188+
OverflowError, // OverflowError
189+
ZeroDivisionError, // ZeroDivisionError
190+
SystemError, // SystemError
191+
ReferenceError, // ReferenceError
192+
MemoryError, // MemoryError
193+
BufferError, // BufferError
194+
Warning, // Warning
195+
UserWarning, // UserWarning
196+
DeprecationWarning, // DeprecationWarning
197+
PendingDeprecationWarning, // PendingDeprecationWarning
198+
SyntaxWarning, // SyntaxWarning
199+
RuntimeWarning, // RuntimeWarning
200+
FutureWarning, // FutureWarning
201+
ImportWarning, // ImportWarning
202+
UnicodeWarning, // UnicodeWarning
203+
BytesWarning, // BytesWarning
204+
ResourceWarning, // ResourceWarning
205+
ConnectionError, // ConnectionError
206+
BlockingIOError, // BlockingIOError
207+
BrokenPipeError, // BrokenPipeError
208+
ChildProcessError, // ChildProcessError
209+
ConnectionAbortedError, // ConnectionAbortedError
210+
ConnectionRefusedError, // ConnectionRefusedError
211+
ConnectionResetError, // ConnectionResetError
212+
FileExistsError, // FileExistsError
213+
FileNotFoundError, // FileNotFoundError
214+
IsADirectoryError, // IsADirectoryError
215+
NotADirectoryError, // NotADirectoryError
216+
InterruptedError, // InterruptedError
217+
PermissionError, // PermissionError
218+
ProcessLookupError, // ProcessLookupError
219+
TimeoutError, // TimeoutError
220+
PFloatInfo, // sys.float_info
221+
PythonModuleDef, // moduledef
222+
PHashInfo, // sys.hash_info
223+
PVersionInfo, // sys.version_info
224+
PFlags, // sys.flags
225+
PThreadInfo, // sys.thread_info
226+
PMap, // map
227+
PZip, // zip
228+
PClassmethod, // classmethod
229+
PBytesIOBuf, // _io._BytesIOBuffer
230+
PIncrementalNewlineDecoder, // _io.IncrementalNewlineDecoder
231+
PStatResult, // os.stat_result
232+
PStatvfsResult, // os.statvfs_result
233+
PTerminalSize, // os.terminal_size
234+
PScandirIterator, // posix.ScandirIterator
235+
PDirEntry, // posix.DirEntry
236+
PUnameResult, // posix.uname_result
237+
PStructTime, // time.struct_time
238+
PDictItemIterator, // dict_itemiterator
239+
PDictKeyIterator, // dict_keyiterator
240+
PDictValueIterator, // dict_valueiterator
241+
PAccumulate, // itertools.accumulate
242+
PCombinations, // itertools.combinations
243+
PCombinationsWithReplacement, // itertools.combinations_with_replacement
244+
PCycle, // itertools.cycle
245+
PDropwhile, // itertools.dropwhile
246+
PTakewhile, // itertools.takewhile
247+
PIslice, // itertools.islice
248+
PStarmap, // itertools.starmap
249+
PChain, // itertools.chain
250+
PCompress, // itertools.compress
251+
PFilterfalse, // itertools.filterfalse
252+
PCount, // itertools.count
253+
PZipLongest, // itertools.zip_longest
254+
PPermutations, // itertools.permutations
255+
PProduct, // itertools.product
256+
PRepeat, // itertools.repeat
257+
PGroupBy, // itertools.groupby
258+
PTeeDataObject, // itertools._tee_dataobject
259+
PDefaultDict, // collections.defaultdict
260+
PDequeIter, // _collections._deque_iterator
261+
PDequeRevIter, // _collections._deque_reverse_iterator
262+
PTupleGetter // _collections._tuplegetter
263+
-> 0;
264+
case PythonClass -> 368; // type
265+
case PSet, // set
266+
PFrozenSet // frozenset
267+
-> 192;
268+
case PMemoryView, // memoryview
269+
PCode // code
270+
-> 136;
271+
case PBuiltinFunctionOrMethod, // builtin_function_or_method
272+
PGenerator, // generator
273+
PCoroutine, // coroutine
274+
PythonModule, // module
275+
PThreadLocal, // _thread._local
276+
PRLock, // _thread.RLock
277+
PBufferedRWPair, // _io.BufferedRWPair
278+
PAsyncGenerator // async_generator
279+
-> 40;
280+
case PMethod, // method
281+
PFileIO, // _io.FileIO
282+
PTee // itertools._tee
283+
-> 32;
284+
case PFunction -> 80; // function
285+
case PIOBase, // _io._IOBase
286+
PRawIOBase, // _io._RawIOBase
287+
PBufferedIOBase, // _io._BufferedIOBase
288+
PTextIOBase // _io._TextIOBase
289+
-> 24;
290+
case PBytesIO, // _io.BytesIO
291+
PPartial // functools.partial
292+
-> 48;
293+
case PStringIO -> 112; // _io.StringIO
294+
case PBufferedReader, // _io.BufferedReader
295+
PBufferedWriter, // _io.BufferedWriter
296+
PBufferedRandom // _io.BufferedRandom
297+
-> 144;
298+
case PTextIOWrapper -> 176; // _io.TextIOWrapper
299+
300+
default -> -1; // unknown or not implemented
301+
// @formatter:on
302+
};
303+
}
304+
103305
private static class WeakrefCallbackAction extends AsyncHandler.AsyncPythonAction {
104306
private final WeakRefStorage[] references;
105307
private int index;
@@ -179,8 +381,31 @@ public abstract static class ReferenceTypeNode extends PythonTernaryBuiltinNode
179381
@Child private CExtNodes.GetTypeMemberNode getTpWeaklistoffsetNode;
180382

181383
@Specialization(guards = "!isNativeObject(object)")
182-
public PReferenceType refType(Object cls, Object object, @SuppressWarnings("unused") PNone none) {
183-
return factory().createReferenceType(cls, object, null, getWeakReferenceQueue());
384+
public PReferenceType refType(Object cls, Object object, @SuppressWarnings("unused") PNone none,
385+
@Cached InlinedGetClassNode getClassNode,
386+
@Cached ReadAttributeFromObjectNode getAttrNode,
387+
@Cached WriteAttributeToDynamicObjectNode setAttrNode) {
388+
Object obj = object;
389+
if (object instanceof PythonBuiltinClassType tobj) {
390+
obj = getContext().getCore().lookupType(tobj);
391+
}
392+
393+
Object clazz = getClassNode.execute(this, obj);
394+
boolean allowed = true;
395+
if (clazz instanceof PythonBuiltinClassType type) {
396+
allowed = type.getWeaklistoffset() != 0;
397+
}
398+
if (!allowed) {
399+
throw raise(TypeError, ErrorMessages.CANNOT_CREATE_WEAK_REFERENCE_TO, obj);
400+
}
401+
Object wr = getAttrNode.execute(obj, __WEAKLIST__);
402+
if (wr != PNone.NO_VALUE) {
403+
return (PReferenceType) wr; // is must be a PReferenceType instance.
404+
}
405+
406+
PReferenceType ref = factory().createReferenceType(cls, obj, null, getWeakReferenceQueue());
407+
setAttrNode.execute(obj, __WEAKLIST__, ref);
408+
return ref;
184409
}
185410

186411
@Specialization(guards = {"!isNativeObject(object)", "!isPNone(callback)"})
@@ -223,7 +448,7 @@ public PReferenceType refType(Object cls, PythonAbstractNativeObject pythonObjec
223448

224449
@Fallback
225450
public PReferenceType refType(@SuppressWarnings("unused") Object cls, Object object, @SuppressWarnings("unused") Object callback) {
226-
throw raise(PythonErrorType.TypeError, ErrorMessages.CANNOT_CREATE_WEAK_REFERENCE_TO, object);
451+
throw raise(TypeError, ErrorMessages.CANNOT_CREATE_WEAK_REFERENCE_TO, object);
227452
}
228453

229454
@SuppressWarnings("unchecked")

0 commit comments

Comments
 (0)