diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 6c8d03319893..6b5ea0d5af61 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -670,7 +670,7 @@ def _verify_arg_default_value( stub_arg: nodes.Argument, runtime_arg: inspect.Parameter ) -> Iterator[str]: """Checks whether argument default values are compatible.""" - if runtime_arg.default != inspect.Parameter.empty: + if runtime_arg.default is not inspect.Parameter.empty: if stub_arg.kind.is_required(): yield ( f'runtime argument "{runtime_arg.name}" ' @@ -705,18 +705,26 @@ def _verify_arg_default_value( stub_default is not UNKNOWN and stub_default is not ... and runtime_arg.default is not UNREPRESENTABLE - and ( - stub_default != runtime_arg.default - # We want the types to match exactly, e.g. in case the stub has - # True and the runtime has 1 (or vice versa). - or type(stub_default) is not type(runtime_arg.default) - ) ): - yield ( - f'runtime argument "{runtime_arg.name}" ' - f"has a default value of {runtime_arg.default!r}, " - f"which is different from stub argument default {stub_default!r}" - ) + defaults_match = True + # We want the types to match exactly, e.g. in case the stub has + # True and the runtime has 1 (or vice versa). + if type(stub_default) is not type(runtime_arg.default): + defaults_match = False + else: + try: + defaults_match = bool(stub_default == runtime_arg.default) + except Exception: + # Exception can be raised in bool dunder method (e.g. numpy arrays) + # At this point, consider the default to be different, it is probably + # too complex to put in a stub anyway. + defaults_match = False + if not defaults_match: + yield ( + f'runtime argument "{runtime_arg.name}" ' + f"has a default value of {runtime_arg.default!r}, " + f"which is different from stub argument default {stub_default!r}" + ) else: if stub_arg.kind.is_optional(): yield ( @@ -758,7 +766,7 @@ def get_type(arg: Any) -> str | None: def has_default(arg: Any) -> bool: if isinstance(arg, inspect.Parameter): - return bool(arg.default != inspect.Parameter.empty) + return arg.default is not inspect.Parameter.empty if isinstance(arg, nodes.Argument): return arg.kind.is_optional() raise AssertionError @@ -1628,13 +1636,13 @@ def anytype() -> mypy.types.AnyType: arg_names.append( None if arg.kind == inspect.Parameter.POSITIONAL_ONLY else arg.name ) - has_default = arg.default == inspect.Parameter.empty + no_default = arg.default is inspect.Parameter.empty if arg.kind == inspect.Parameter.POSITIONAL_ONLY: - arg_kinds.append(nodes.ARG_POS if has_default else nodes.ARG_OPT) + arg_kinds.append(nodes.ARG_POS if no_default else nodes.ARG_OPT) elif arg.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: - arg_kinds.append(nodes.ARG_POS if has_default else nodes.ARG_OPT) + arg_kinds.append(nodes.ARG_POS if no_default else nodes.ARG_OPT) elif arg.kind == inspect.Parameter.KEYWORD_ONLY: - arg_kinds.append(nodes.ARG_NAMED if has_default else nodes.ARG_NAMED_OPT) + arg_kinds.append(nodes.ARG_NAMED if no_default else nodes.ARG_NAMED_OPT) elif arg.kind == inspect.Parameter.VAR_POSITIONAL: arg_kinds.append(nodes.ARG_STAR) elif arg.kind == inspect.Parameter.VAR_KEYWORD: diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 6dc1feb67089..f099ebdc55a5 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -529,6 +529,18 @@ def f11(text=None) -> None: pass error="f11", ) + # Simulate numpy ndarray.__bool__ that raises an error + yield Case( + stub="def f12(x=1): ...", + runtime=""" + class _ndarray: + def __eq__(self, obj): return self + def __bool__(self): raise ValueError + def f12(x=_ndarray()) -> None: pass + """, + error="f12", + ) + @collect_cases def test_static_class_method(self) -> Iterator[Case]: yield Case(