Skip to content

Commit fefd829

Browse files
committed
implement __class_getitem__ from PEP 560
1 parent 36d07cf commit fefd829

File tree

5 files changed

+176
-7
lines changed

5 files changed

+176
-7
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) 2018, 2019, Oracle and/or its affiliates.
2+
# Copyright (C) 1996-2017 Python Software Foundation
3+
#
4+
# Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
5+
import unittest
6+
7+
8+
class TestClassGetitem(unittest.TestCase):
9+
def test_class_getitem(self):
10+
getitem_args = []
11+
class C:
12+
def __class_getitem__(*args, **kwargs):
13+
getitem_args.extend([args, kwargs])
14+
return None
15+
C[int, str]
16+
self.assertEqual(getitem_args[0], (C, (int, str)))
17+
self.assertEqual(getitem_args[1], {})
18+
19+
def test_class_getitem(self):
20+
class C:
21+
def __class_getitem__(cls, item):
22+
return 'C[{0}]'.format(item.__name__)
23+
self.assertEqual(C[int], 'C[int]')
24+
self.assertEqual(C[C], 'C[C]')
25+
26+
def test_class_getitem_inheritance(self):
27+
class C:
28+
def __class_getitem__(cls, item):
29+
return '{0}[{1}]'.format(cls.__name__, item.__name__)
30+
class D(C): ...
31+
self.assertEqual(D[int], 'D[int]')
32+
self.assertEqual(D[D], 'D[D]')
33+
34+
def test_class_getitem_inheritance_2(self):
35+
class C:
36+
def __class_getitem__(cls, item):
37+
return 'Should not see this'
38+
class D(C):
39+
def __class_getitem__(cls, item):
40+
return '{0}[{1}]'.format(cls.__name__, item.__name__)
41+
self.assertEqual(D[int], 'D[int]')
42+
self.assertEqual(D[D], 'D[D]')
43+
44+
def test_class_getitem_classmethod(self):
45+
class C:
46+
@classmethod
47+
def __class_getitem__(cls, item):
48+
return '{0}[{1}]'.format(cls.__name__, item.__name__)
49+
class D(C): ...
50+
self.assertEqual(D[int], 'D[int]')
51+
self.assertEqual(D[D], 'D[D]')
52+
53+
# def test_class_getitem_patched(self):
54+
# class C:
55+
# def __init_subclass__(cls):
56+
# def __class_getitem__(cls, item):
57+
# return '{0}[{1}]'.format(cls.__name__, item.__name__)
58+
# cls.__class_getitem__ = classmethod(__class_getitem__)
59+
# class D(C): ...
60+
# self.assertEqual(D[int], 'D[int]')
61+
# self.assertEqual(D[D], 'D[D]')
62+
63+
def test_class_getitem_with_builtins(self):
64+
class A(dict):
65+
called_with = None
66+
67+
def __class_getitem__(cls, item):
68+
cls.called_with = item
69+
class B(A):
70+
pass
71+
self.assertIs(B.called_with, None)
72+
B[int]
73+
self.assertIs(B.called_with, int)
74+
75+
def test_class_getitem_errors(self):
76+
class C_too_few:
77+
def __class_getitem__(cls):
78+
return None
79+
with self.assertRaises(TypeError):
80+
C_too_few[int]
81+
class C_too_many:
82+
def __class_getitem__(cls, one, two):
83+
return None
84+
with self.assertRaises(TypeError):
85+
C_too_many[int]
86+
87+
def test_class_getitem_errors_2(self):
88+
class C:
89+
def __class_getitem__(cls, item):
90+
return None
91+
with self.assertRaises(TypeError):
92+
C()[int]
93+
class E: ...
94+
e = E()
95+
e.__class_getitem__ = lambda cls, item: 'This will not work'
96+
with self.assertRaises(TypeError):
97+
e[int]
98+
class C_not_callable:
99+
__class_getitem__ = "Surprise!"
100+
with self.assertRaises(TypeError):
101+
C_not_callable[int]
102+
103+
def test_class_getitem_metaclass(self):
104+
class Meta(type):
105+
def __class_getitem__(cls, item):
106+
return '{0}[{1}]'.format(cls.__name__, item.__name__)
107+
self.assertEqual(Meta[int], 'Meta[int]')
108+
109+
def test_class_getitem_metaclass_2(self):
110+
class Meta(type):
111+
def __getitem__(cls, item):
112+
return 'from metaclass'
113+
class C(metaclass=Meta):
114+
def __class_getitem__(cls, item):
115+
return 'from __class_getitem__'
116+
self.assertEqual(C[int], 'from metaclass')

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
import com.oracle.graal.python.builtins.objects.type.TypeNodes.GetMroNode;
130130
import com.oracle.graal.python.builtins.objects.type.TypeNodes.GetNameNode;
131131
import com.oracle.graal.python.nodes.PGuards;
132+
import com.oracle.graal.python.nodes.SpecialMethodNames;
132133
import com.oracle.graal.python.nodes.attributes.GetAttributeNode.GetAnyAttributeNode;
133134
import com.oracle.graal.python.nodes.attributes.LookupAttributeInMRONode;
134135
import com.oracle.graal.python.nodes.attributes.LookupInheritedAttributeNode;
@@ -1896,6 +1897,23 @@ private Object typeMetaclass(String name, PTuple bases, PDict namespace, PythonA
18961897
Object value = entry.getValue();
18971898
if (__SLOTS__.equals(key)) {
18981899
slots = value;
1900+
} else if (SpecialMethodNames.__NEW__.equals(key)) {
1901+
// TODO: see CPython: if it's a plain function, make it a
1902+
// static function
1903+
1904+
// tfel: this requires a little bit of refactoring on our
1905+
// side that I don't want to do now
1906+
pythonClass.setAttribute(key, value);
1907+
} else if (SpecialMethodNames.__INIT_SUBCLASS__.equals(key) ||
1908+
SpecialMethodNames.__CLASS_GETITEM__.equals(key)) {
1909+
// see CPython: Special-case __init_subclass__ and
1910+
// __class_getitem__: if they are plain functions, make them
1911+
// classmethods
1912+
if (value instanceof PFunction) {
1913+
pythonClass.setAttribute(key, factory().createClassmethod(value));
1914+
} else {
1915+
pythonClass.setAttribute(key, value);
1916+
}
18991917
} else {
19001918
pythonClass.setAttribute(key, value);
19011919
}
@@ -1957,10 +1975,6 @@ private Object typeMetaclass(String name, PTuple bases, PDict namespace, PythonA
19571975
}
19581976
}
19591977

1960-
// TODO: tfel special case __new__: if it's a plain function, make it a static function
1961-
// TODO: tfel Special-case __init_subclass__: if it's a plain function, make it a
1962-
// classmethod
1963-
19641978
return pythonClass;
19651979
}
19661980

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/SpecialMethodNames.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ public abstract class SpecialMethodNames {
161161
public static final String TOBYTES = "tobytes";
162162
public static final String DECODE = "decode";
163163
public static final String __SIZEOF__ = "__sizeof__";
164+
public static final String __CLASS_GETITEM__ = "__class_getitem__";
164165

165166
public static final String RICHCMP = "__truffle_richcompare__";
166167
public static final String TRUFFLE_SOURCE = "__truffle_source__";

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/subscript/GetItemNode.java

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,31 @@
2525
*/
2626
package com.oracle.graal.python.nodes.subscript;
2727

28+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__CLASS_GETITEM__;
2829
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GETITEM__;
30+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.AttributeError;
2931
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
3032

33+
import com.oracle.graal.python.builtins.objects.type.PythonAbstractClass;
3134
import com.oracle.graal.python.nodes.PRaiseNode;
35+
import com.oracle.graal.python.nodes.attributes.GetAttributeNode;
36+
import com.oracle.graal.python.nodes.call.CallNode;
3237
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
3338
import com.oracle.graal.python.nodes.expression.BinaryOpNode;
3439
import com.oracle.graal.python.nodes.expression.ExpressionNode;
3540
import com.oracle.graal.python.nodes.frame.ReadNode;
41+
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
3642
import com.oracle.graal.python.nodes.statement.StatementNode;
43+
import com.oracle.graal.python.runtime.exception.PException;
3744
import com.oracle.truffle.api.CompilerDirectives;
45+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
3846
import com.oracle.truffle.api.dsl.Specialization;
3947
import com.oracle.truffle.api.frame.VirtualFrame;
4048
import com.oracle.truffle.api.nodes.NodeInfo;
4149

4250
@NodeInfo(shortName = __GETITEM__)
4351
public abstract class GetItemNode extends BinaryOpNode implements ReadNode {
52+
private static final String P_OBJECT_IS_NOT_SUBSCRIPTABLE = "'%p' object is not subscriptable";
4453

4554
@Child private LookupAndCallBinaryNode callGetitemNode;
4655

@@ -61,11 +70,39 @@ public Object doSpecialObject(Object primary, Object index) {
6170
if (callGetitemNode == null) {
6271
CompilerDirectives.transferToInterpreterAndInvalidate();
6372
callGetitemNode = insert(LookupAndCallBinaryNode.create(__GETITEM__, null, () -> new LookupAndCallBinaryNode.NotImplementedHandler() {
64-
@Child private PRaiseNode raiseNode = PRaiseNode.create();
73+
@CompilationFinal private IsBuiltinClassProfile isBuiltinClassProfile;
74+
@Child private PRaiseNode raiseNode;
75+
@Child private CallNode callClassGetItemNode;
76+
@Child private GetAttributeNode getClassGetItemNode;
6577

6678
@Override
67-
public Object execute(Object arg, @SuppressWarnings("unused") Object arg2) {
68-
throw raiseNode.raise(TypeError, "'%p' object is not subscriptable", arg);
79+
public Object execute(Object arg, Object arg2) {
80+
if (arg instanceof PythonAbstractClass) {
81+
if (getClassGetItemNode == null) {
82+
CompilerDirectives.transferToInterpreterAndInvalidate();
83+
getClassGetItemNode = insert(GetAttributeNode.create(__CLASS_GETITEM__));
84+
isBuiltinClassProfile = IsBuiltinClassProfile.create();
85+
}
86+
Object classGetItem = null;
87+
try {
88+
classGetItem = getClassGetItemNode.executeObject(arg);
89+
} catch (PException e) {
90+
e.expect(AttributeError, isBuiltinClassProfile);
91+
// fall through to normal error handling
92+
}
93+
if (classGetItem != null) {
94+
if (callClassGetItemNode == null) {
95+
CompilerDirectives.transferToInterpreterAndInvalidate();
96+
callClassGetItemNode = insert(CallNode.create());
97+
}
98+
return callClassGetItemNode.execute(null, classGetItem, arg2);
99+
}
100+
}
101+
if (raiseNode == null) {
102+
CompilerDirectives.transferToInterpreterAndInvalidate();
103+
raiseNode = insert(PRaiseNode.create());
104+
}
105+
throw raiseNode.raise(TypeError, P_OBJECT_IS_NOT_SUBSCRIPTABLE, arg);
69106
}
70107
}));
71108
}

mx.graalpython/copyrights/overrides

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ graalpython/com.oracle.graal.python.test/src/tests/test_getattribute-bimorphic-i
236236
graalpython/com.oracle.graal.python.test/src/tests/test_if-class-none.py,zippy.copyright
237237
graalpython/com.oracle.graal.python.test/src/tests/test_if.py,zippy.copyright
238238
graalpython/com.oracle.graal.python.test/src/tests/test_itertools.py,python.copyright
239+
graalpython/com.oracle.graal.python.test/src/tests/test_genericclass.py,python.copyright
239240
graalpython/com.oracle.graal.python.test/src/tests/test_list.py,python.copyright
240241
graalpython/com.oracle.graal.python.test/src/tests/test_builtin.py,python.copyright
241242
graalpython/com.oracle.graal.python.test/src/tests/test_mandelbrot3.py,benchmarks.copyright

0 commit comments

Comments
 (0)