@@ -76,72 +76,166 @@ class C:
7676 new_version = type_get_version (C )
7777 self .assertEqual (new_version , orig_version + 5 )
7878
79- def test_specialization_user_type_no_tag_overflow (self ):
79+ _clear_type_cache ()
80+
81+
82+ @support .cpython_only
83+ class TypeCacheWithSpecializationTests (unittest .TestCase ):
84+ def tearDown (self ):
85+ _clear_type_cache ()
86+
87+ def _assign_and_check_valid_version (self , user_type ):
88+ type_modified (user_type )
89+ type_assign_version (user_type )
90+ self .assertNotEqual (type_get_version (user_type ), 0 )
91+
92+ def _assign_and_check_version_0 (self , user_type ):
93+ type_modified (user_type )
94+ type_assign_specific_version_unsafe (user_type , 0 )
95+ self .assertEqual (type_get_version (user_type ), 0 )
96+
97+ def _all_opnames (self , func ):
98+ return set (instr .opname for instr in dis .Bytecode (func , adaptive = True ))
99+
100+ def _check_specialization (self , func , arg , opname , * , should_specialize ):
101+ self .assertIn (opname , self ._all_opnames (func ))
102+
103+ for _ in range (100 ):
104+ func (arg )
105+
106+ if should_specialize :
107+ self .assertNotIn (opname , self ._all_opnames (func ))
108+ else :
109+ self .assertIn (opname , self ._all_opnames (func ))
110+
111+ def test_class_load_attr_specialization_user_type (self ):
80112 class A :
81113 def foo (self ):
82114 pass
83115
84- class B :
85- def foo (self ):
86- pass
116+ self ._assign_and_check_valid_version (A )
117+
118+ def load_foo_1 (type_ ):
119+ type_ .foo
120+
121+ self ._check_specialization (load_foo_1 , A , "LOAD_ATTR" , should_specialize = True )
122+ del load_foo_1
87123
88- type_modified (A )
89- type_assign_version (A )
90- type_modified (B )
91- type_assign_version (B )
92- self .assertNotEqual (type_get_version (A ), 0 )
93- self .assertNotEqual (type_get_version (B ), 0 )
94- self .assertNotEqual (type_get_version (A ), type_get_version (B ))
124+ self ._assign_and_check_version_0 (A )
95125
96- def get_foo (type_ ):
126+ def load_foo_2 (type_ ):
97127 return type_ .foo
98128
99- self .assertIn (
100- "LOAD_ATTR" ,
101- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
102- )
129+ self ._check_specialization (load_foo_2 , A , "LOAD_ATTR" , should_specialize = False )
103130
104- get_foo (A )
105- get_foo (A )
131+ def test_class_load_attr_specialization_static_type (self ):
132+ self ._assign_and_check_valid_version (str )
133+ self ._assign_and_check_valid_version (bytes )
106134
107- # check that specialization has occurred
108- self .assertNotIn (
109- "LOAD_ATTR" ,
110- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
111- )
135+ def get_capitalize_1 (type_ ):
136+ return type_ .capitalize
112137
113- def test_specialization_user_type_tag_overflow (self ):
114- class A :
115- def foo (self ):
116- pass
138+ self ._check_specialization (get_capitalize_1 , str , "LOAD_ATTR" , should_specialize = True )
139+ self .assertEqual (get_capitalize_1 (str )('hello' ), 'Hello' )
140+ self .assertEqual (get_capitalize_1 (bytes )(b'hello' ), b'Hello' )
141+ del get_capitalize_1
142+
143+ # Permanently overflow the static type version counter, and force str and bytes
144+ # to have tp_version_tag == 0
145+ for _ in range (2 ** 16 ):
146+ type_modified (str )
147+ type_assign_version (str )
148+ type_modified (bytes )
149+ type_assign_version (bytes )
150+
151+ self .assertEqual (type_get_version (str ), 0 )
152+ self .assertEqual (type_get_version (bytes ), 0 )
153+
154+ def get_capitalize_2 (type_ ):
155+ return type_ .capitalize
156+
157+ self ._check_specialization (get_capitalize_2 , str , "LOAD_ATTR" , should_specialize = False )
158+ self .assertEqual (get_capitalize_2 (str )('hello' ), 'Hello' )
159+ self .assertEqual (get_capitalize_2 (bytes )(b'hello' ), b'Hello' )
160+
161+ def test_property_load_attr_specialization_user_type (self ):
162+ class G :
163+ @property
164+ def x (self ):
165+ return 9
166+
167+ self ._assign_and_check_valid_version (G )
117168
169+ def load_x_1 (instance ):
170+ instance .x
171+
172+ self ._check_specialization (load_x_1 , G (), "LOAD_ATTR" , should_specialize = True )
173+ del load_x_1
174+
175+ self ._assign_and_check_version_0 (G )
176+
177+ def load_x_2 (instance ):
178+ instance .x
179+
180+ self ._check_specialization (load_x_2 , G (), "LOAD_ATTR" , should_specialize = False )
181+
182+ def test_store_attr_specialization_user_type (self ):
118183 class B :
119- def foo (self ):
184+ __slots__ = ("bar" ,)
185+
186+ self ._assign_and_check_valid_version (B )
187+
188+ def store_bar_1 (type_ ):
189+ type_ .bar = 10
190+
191+ self ._check_specialization (store_bar_1 , B (), "STORE_ATTR" , should_specialize = True )
192+ del store_bar_1
193+
194+ self ._assign_and_check_version_0 (B )
195+
196+ def store_bar_2 (type_ ):
197+ type_ .bar = 10
198+
199+ self ._check_specialization (store_bar_2 , B (), "STORE_ATTR" , should_specialize = False )
200+
201+ def test_class_call_specialization_user_type (self ):
202+ class F :
203+ def __init__ (self ):
120204 pass
121205
122- type_modified (A )
123- type_assign_specific_version_unsafe (A , 0 )
124- type_modified (B )
125- type_assign_specific_version_unsafe (B , 0 )
126- self .assertEqual (type_get_version (A ), 0 )
127- self .assertEqual (type_get_version (B ), 0 )
206+ self ._assign_and_check_valid_version (F )
128207
129- def get_foo (type_ ):
130- return type_ .foo
208+ def call_class_1 (type_ ):
209+ type_ ()
210+
211+ self ._check_specialization (call_class_1 , F , "CALL" , should_specialize = True )
212+ del call_class_1
213+
214+ self ._assign_and_check_version_0 (F )
215+
216+ def call_class_2 (type_ ):
217+ type_ ()
218+
219+ self ._check_specialization (call_class_2 , F , "CALL" , should_specialize = False )
220+
221+ def test_to_bool_specialization_user_type (self ):
222+ class H :
223+ pass
224+
225+ self ._assign_and_check_valid_version (H )
226+
227+ def to_bool_1 (instance ):
228+ not instance
229+
230+ self ._check_specialization (to_bool_1 , H (), "TO_BOOL" , should_specialize = True )
231+ del to_bool_1
131232
132- self .assertIn (
133- "LOAD_ATTR" ,
134- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
135- )
233+ self ._assign_and_check_version_0 (H )
136234
137- get_foo ( A )
138- get_foo ( A )
235+ def to_bool_2 ( instance ):
236+ not instance
139237
140- # check that specialization has not occurred due to version tag == 0
141- self .assertIn (
142- "LOAD_ATTR" ,
143- [instr .opname for instr in dis .Bytecode (get_foo , adaptive = True )],
144- )
238+ self ._check_specialization (to_bool_2 , H (), "TO_BOOL" , should_specialize = False )
145239
146240
147241if __name__ == "__main__" :
0 commit comments