|
8 | 8 | import itertools |
9 | 9 | import pickle |
10 | 10 | from string.templatelib import Template |
11 | | -import types |
12 | 11 | import typing |
13 | 12 | import unittest |
14 | | -import unittest.mock |
15 | 13 | from annotationlib import ( |
16 | 14 | Format, |
17 | 15 | ForwardRef, |
@@ -1209,89 +1207,145 @@ def evaluate(format, exc=NotImplementedError): |
1209 | 1207 |
|
1210 | 1208 |
|
1211 | 1209 | class TestCallAnnotateFunction(unittest.TestCase): |
1212 | | - def _annotate_mock(self): |
| 1210 | + # Tests for user defined annotate functions. |
| 1211 | + |
| 1212 | + # Format and NotImplementedError are provided as arguments so they exist in |
| 1213 | + # the fake globals namespace. |
| 1214 | + # This avoids non-matching conditions passing by being converted to stringifiers. |
| 1215 | + # See: https://github.com/python/cpython/issues/138764 |
| 1216 | + |
| 1217 | + def test_user_annotate_value(self): |
1213 | 1218 | def annotate(format, /): |
1214 | 1219 | if format == Format.VALUE: |
1215 | 1220 | return {"x": str} |
1216 | 1221 | else: |
1217 | 1222 | raise NotImplementedError(format) |
1218 | 1223 |
|
1219 | | - annotate_mock = unittest.mock.MagicMock( |
1220 | | - wraps=annotate |
| 1224 | + annotations = annotationlib.call_annotate_function( |
| 1225 | + annotate, |
| 1226 | + Format.VALUE, |
1221 | 1227 | ) |
1222 | 1228 |
|
1223 | | - # Add missing magic attributes needed |
1224 | | - required_magic = [ |
1225 | | - "__builtins__", |
1226 | | - "__closure__", |
1227 | | - "__code__", |
1228 | | - "__defaults__", |
1229 | | - "__globals__", |
1230 | | - "__kwdefaults__", |
1231 | | - ] |
| 1229 | + self.assertEqual(annotations, {"x": str}) |
| 1230 | + |
| 1231 | + def test_user_annotate_forwardref_supported(self): |
| 1232 | + # If Format.FORWARDREF is supported prefer it over Format.VALUE |
| 1233 | + def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedError): |
| 1234 | + if format == __Format.VALUE: |
| 1235 | + return {'x': str} |
| 1236 | + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: |
| 1237 | + return {'x': int} |
| 1238 | + elif format == __Format.FORWARDREF: |
| 1239 | + return {'x': float} |
| 1240 | + else: |
| 1241 | + raise __NotImplementedError(format) |
1232 | 1242 |
|
1233 | | - for attrib in required_magic: |
1234 | | - setattr(annotate_mock, attrib, getattr(annotate, attrib)) |
| 1243 | + annotations = annotationlib.call_annotate_function( |
| 1244 | + annotate, |
| 1245 | + Format.FORWARDREF |
| 1246 | + ) |
1235 | 1247 |
|
1236 | | - return annotate_mock |
| 1248 | + self.assertEqual(annotations, {"x": float}) |
1237 | 1249 |
|
1238 | | - def test_user_annotate_value(self): |
1239 | | - annotate = self._annotate_mock() |
| 1250 | + def test_user_annotate_forwardref_fakeglobals(self): |
| 1251 | + # If Format.FORWARDREF is not supported, use Format.VALUE_WITH_FAKE_GLOBALS |
| 1252 | + # before falling back to Format.VALUE |
| 1253 | + def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedError): |
| 1254 | + if format == __Format.VALUE: |
| 1255 | + return {'x': str} |
| 1256 | + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: |
| 1257 | + return {'x': int} |
| 1258 | + else: |
| 1259 | + raise __NotImplementedError(format) |
1240 | 1260 |
|
1241 | 1261 | annotations = annotationlib.call_annotate_function( |
1242 | 1262 | annotate, |
1243 | | - Format.VALUE, |
| 1263 | + Format.FORWARDREF |
1244 | 1264 | ) |
1245 | 1265 |
|
1246 | | - self.assertEqual(annotations, {"x": str}) |
1247 | | - annotate.assert_called_once_with(Format.VALUE) |
| 1266 | + self.assertEqual(annotations, {"x": int}) |
1248 | 1267 |
|
1249 | | - def test_user_annotate_forwardref(self): |
1250 | | - annotate = self._annotate_mock() |
| 1268 | + def test_user_annotate_forwardref_value_fallback(self): |
| 1269 | + # If Format.FORWARDREF and Format.VALUE_WITH_FAKE_GLOBALS are not supported |
| 1270 | + # use Format.VALUE |
| 1271 | + def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedError): |
| 1272 | + if format == __Format.VALUE: |
| 1273 | + return {"x": str} |
| 1274 | + else: |
| 1275 | + raise __NotImplementedError(format) |
1251 | 1276 |
|
1252 | | - new_annotate = None |
1253 | | - functype = types.FunctionType |
| 1277 | + annotations = annotationlib.call_annotate_function( |
| 1278 | + annotate, |
| 1279 | + Format.FORWARDREF, |
| 1280 | + ) |
1254 | 1281 |
|
1255 | | - def functiontype(*args, **kwargs): |
1256 | | - nonlocal new_annotate |
1257 | | - new_func = unittest.mock.MagicMock(wraps=functype(*args, **kwargs)) |
1258 | | - new_annotate = new_func |
1259 | | - return new_func |
| 1282 | + self.assertEqual(annotations, {"x": str}) |
1260 | 1283 |
|
1261 | | - with unittest.mock.patch("types.FunctionType", new=functiontype): |
1262 | | - annotations = annotationlib.call_annotate_function( |
1263 | | - annotate, |
1264 | | - Format.FORWARDREF, |
1265 | | - ) |
| 1284 | + def test_user_annotate_string_supported(self): |
| 1285 | + # If Format.STRING is supported prefer it over Format.VALUE |
| 1286 | + def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedError): |
| 1287 | + if format == __Format.VALUE: |
| 1288 | + return {'x': str} |
| 1289 | + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: |
| 1290 | + return {'x': int} |
| 1291 | + elif format == __Format.STRING: |
| 1292 | + return {'x': "float"} |
| 1293 | + else: |
| 1294 | + raise __NotImplementedError(format) |
1266 | 1295 |
|
1267 | | - # The call with Format.VALUE_WITH_FAKE_GLOBALS is not |
1268 | | - # on the original function. |
1269 | | - annotate.assert_has_calls([ |
1270 | | - unittest.mock.call(Format.FORWARDREF), |
1271 | | - unittest.mock.call(Format.VALUE), |
1272 | | - ]) |
| 1296 | + annotations = annotationlib.call_annotate_function( |
| 1297 | + annotate, |
| 1298 | + Format.STRING, |
| 1299 | + ) |
1273 | 1300 |
|
1274 | | - new_annotate.assert_called_once_with(Format.VALUE_WITH_FAKE_GLOBALS) |
| 1301 | + self.assertEqual(annotations, {"x": "float"}) |
1275 | 1302 |
|
1276 | | - self.assertEqual(annotations, {"x": str}) |
| 1303 | + def test_user_annotate_string_fakeglobals(self): |
| 1304 | + # If Format.STRING is not supported but Format.VALUE_WITH_FAKE_GLOBALS is |
| 1305 | + # prefer that over Format.VALUE |
| 1306 | + def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedError): |
| 1307 | + if format == __Format.VALUE: |
| 1308 | + return {'x': str} |
| 1309 | + elif format == __Format.VALUE_WITH_FAKE_GLOBALS: |
| 1310 | + return {'x': int} |
| 1311 | + else: |
| 1312 | + raise __NotImplementedError(format) |
1277 | 1313 |
|
| 1314 | + annotations = annotationlib.call_annotate_function( |
| 1315 | + annotate, |
| 1316 | + Format.STRING, |
| 1317 | + ) |
1278 | 1318 |
|
1279 | | - def test_user_annotate_string(self): |
1280 | | - annotate = self._annotate_mock() |
| 1319 | + self.assertEqual(annotations, {"x": "int"}) |
| 1320 | + |
| 1321 | + def test_user_annotate_string_value_fallback(self): |
| 1322 | + # If Format.STRING and Format.VALUE_WITH_FAKE_GLOBALS are not |
| 1323 | + # supported fall back to Format.VALUE and convert to strings |
| 1324 | + def annotate(format, /, __Format=Format, __NotImplementedError=NotImplementedError): |
| 1325 | + if format == __Format.VALUE: |
| 1326 | + return {"x": str} |
| 1327 | + else: |
| 1328 | + raise __NotImplementedError(format) |
1281 | 1329 |
|
1282 | 1330 | annotations = annotationlib.call_annotate_function( |
1283 | 1331 | annotate, |
1284 | 1332 | Format.STRING, |
1285 | 1333 | ) |
1286 | 1334 |
|
1287 | | - annotate.assert_has_calls([ |
1288 | | - unittest.mock.call(Format.STRING), |
1289 | | - unittest.mock.call(Format.VALUE_WITH_FAKE_GLOBALS), |
1290 | | - unittest.mock.call(Format.VALUE), |
1291 | | - ]) |
1292 | | - |
1293 | 1335 | self.assertEqual(annotations, {"x": "str"}) |
1294 | 1336 |
|
| 1337 | + def test_condition_not_stringified(self): |
| 1338 | + # Make sure the first condition isn't evaluated as True by being converted |
| 1339 | + # to a _Stringifier |
| 1340 | + def annotate(format, /): |
| 1341 | + if format == Format.FORWARDREF: |
| 1342 | + return {"x": str} |
| 1343 | + else: |
| 1344 | + raise NotImplementedError(format) |
| 1345 | + |
| 1346 | + with self.assertRaises(NotImplementedError): |
| 1347 | + _ = annotationlib.call_annotate_function(annotate, Format.STRING) |
| 1348 | + |
1295 | 1349 |
|
1296 | 1350 | class MetaclassTests(unittest.TestCase): |
1297 | 1351 | def test_annotated_meta(self): |
|
0 commit comments