|
| 1 | +from unittest.mock import Mock, patch |
| 2 | + |
| 3 | +from parser.common import StrRegion |
| 4 | +from parser.typecheck.typecheck import NameResolver |
| 5 | +from test.common import CommonTestCase |
| 6 | + |
| 7 | + |
| 8 | +class BoundMock(Mock): |
| 9 | + # Not-so-black magic to allow setting the method on the class while ensuring |
| 10 | + # that the wrapper receives the correct `self` value (builtin unittest is |
| 11 | + # a bit broken in this regard, as it passes no self value at all) |
| 12 | + def __init__(self, *args, wraps=None, **kwargs): |
| 13 | + if wraps: |
| 14 | + def wraps_2(*a, **k): # I hope we don't have to pickle this anywhere... |
| 15 | + if (inst := self.__dict__['_inst']) is not None: |
| 16 | + return wraps(inst, *a, **k) |
| 17 | + else: |
| 18 | + return wraps(*a, **k) |
| 19 | + else: |
| 20 | + wraps_2 = None |
| 21 | + super().__init__(*args, wraps=wraps_2, **kwargs) |
| 22 | + self.__dict__['_inst'] = None |
| 23 | + |
| 24 | + def __get__(self, instance, owner=None): |
| 25 | + # This very naive method of finding out the proper instance works |
| 26 | + # because __get__ will be called again every time this is accessed |
| 27 | + if owner is not None: # if accessed on instance |
| 28 | + self.__dict__['_inst'] = instance |
| 29 | + else: |
| 30 | + self.__dict__['_inst'] = None |
| 31 | + return self |
| 32 | + |
| 33 | + |
| 34 | +class TestNameResolve(CommonTestCase): |
| 35 | + def test_top_scope_attr(self): |
| 36 | + src = 'let a = 8, b = 5; a += b; def c(val param) {c(param, a, b);}' |
| 37 | + orig = NameResolver._init # Reliably called exactly once in slow path of .run() |
| 38 | + m = BoundMock(spec_set=orig, wraps=orig) |
| 39 | + with patch.object(NameResolver, '_init', new_callable=lambda: m): |
| 40 | + nr = self.getNameResolver(src) |
| 41 | + self.assertIsNone(nr.top_scope) |
| 42 | + m.assert_not_called() |
| 43 | + v = nr.run() |
| 44 | + self.assertIs(v, nr.top_scope) |
| 45 | + m.assert_called_once() |
| 46 | + v2 = nr.run() |
| 47 | + self.assertIs(v2, v) |
| 48 | + m.assert_called_once() # Still only once |
| 49 | + |
| 50 | + |
| 51 | +class TestNameResolveErrors(CommonTestCase): |
| 52 | + def test_undefined_var(self): |
| 53 | + err = self.assertNameResolveError('foo = 9;') |
| 54 | + self.assertContains(err.msg, "Name 'foo' is not defined") |
| 55 | + self.assertEqual(StrRegion(0, 3), err.region) |
0 commit comments