diff --git a/mitogen/core.py b/mitogen/core.py index 441743d4c..ad4a3ac88 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -792,8 +792,19 @@ def save_exc_inst(self, obj): # In 3.x Unpickler is a class exposing find_class as an overridable, but it # cannot be overridden without subclassing. class _Unpickler(pickle.Unpickler): + + def __init__(self, *args, insecure=False, **kwargs): + super().__init__(*args, **kwargs) + self.__insecure = insecure + def find_class(self, module, func): - return self.find_global(module, func) + try: + return self.find_global(module, func) + except Exception as error: + if not self.__insecure: + raise error + return super().find_class(module, func) + pickle__dumps = pickle.dumps elif PY24: # On Python 2.4, we must use a pure-Python pickler. @@ -977,7 +988,7 @@ def _throw_dead(self): else: raise ChannelError(ChannelError.remote_msg) - def unpickle(self, throw=True, throw_dead=True): + def unpickle(self, throw=True, throw_dead=True, *, insecure=False): """ Unpickle :attr:`data`, optionally raising any exceptions present. @@ -985,6 +996,9 @@ def unpickle(self, throw=True, throw_dead=True): If :data:`True`, raise exceptions, otherwise it is the caller's responsibility. + :param bool insecure: + If :data:`True`, also use possibly unsecure unpickling methods. + :raises CallError: The serialized data contained CallError exception. :raises ChannelError: @@ -997,7 +1011,7 @@ def unpickle(self, throw=True, throw_dead=True): obj = self._unpickled if obj is Message._unpickled: fp = BytesIO(self.data) - unpickler = _Unpickler(fp, **self.UNPICKLER_KWARGS) + unpickler = _Unpickler(fp, insecure=insecure, **self.UNPICKLER_KWARGS) unpickler.find_global = self._find_global try: # Must occur off the broker thread. @@ -3844,7 +3858,7 @@ def forget_chain(cls, chain_id, econtext): econtext.dispatcher._error_by_chain_id.pop(chain_id, None) def _parse_request(self, msg): - data = msg.unpickle(throw=False) + data = msg.unpickle(throw=False, insecure=True) _v and LOG.debug('%r: dispatching %r', self, data) chain_id, modname, klass, func, args, kwargs = data diff --git a/tests/call_function_test.py b/tests/call_function_test.py index 1e838bdad..b60108e9e 100644 --- a/tests/call_function_test.py +++ b/tests/call_function_test.py @@ -42,6 +42,14 @@ class TargetClass: def add_numbers_with_offset(cls, x, y): return cls.offset + x + y + @classmethod + def passing_crazy_type(cls, crazy_cls): + return crazy_cls.__name__ + + @classmethod + def passing_crazy_type_instance(cls, crazy): + return crazy.__class__.__name__ + class CallFunctionTest(testlib.RouterMixin, testlib.TestCase): @@ -58,6 +66,18 @@ def test_succeeds_class_method(self): 103, ) + def test_succeeds_passing_class(self): + self.assertEqual( + self.local.call(TargetClass.passing_crazy_type, CrazyType), + 'CrazyType' + ) + + def test_succeeds_passing_class_instance(self): + self.assertEqual( + self.local.call(TargetClass.passing_crazy_type_instance, CrazyType()), + 'CrazyType' + ) + def test_crashes(self): exc = self.assertRaises(mitogen.core.CallError, lambda: self.local.call(function_that_fails))