Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions astroid/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,14 @@ def _infer_type_new_call(
raise InferenceError(context=context) from e
if not isinstance(mcs, nodes.ClassDef):
# Not a valid first argument.
return None
raise InferenceError(
"type.__new__() requires a class for metaclass", context=context
)
if not mcs.is_subtype_of("builtins.type"):
# Not a valid metaclass.
return None
raise InferenceError(
"type.__new__() metaclass must be a subclass of type", context=context
)

# Verify the name
try:
Expand All @@ -583,10 +587,14 @@ def _infer_type_new_call(
raise InferenceError(context=context) from e
if not isinstance(name, nodes.Const):
# Not a valid name, needs to be a const.
return None
raise InferenceError(
"type.__new__() requires a constant for name", context=context
)
if not isinstance(name.value, str):
# Needs to be a string.
return None
raise InferenceError(
"type.__new__() requires a string for name", context=context
)

# Verify the bases
try:
Expand All @@ -595,14 +603,18 @@ def _infer_type_new_call(
raise InferenceError(context=context) from e
if not isinstance(bases, nodes.Tuple):
# Needs to be a tuple.
return None
raise InferenceError(
"type.__new__() requires a tuple for bases", context=context
)
try:
inferred_bases = [next(elt.infer(context=context)) for elt in bases.elts]
except StopIteration as e:
raise InferenceError(context=context) from e
if any(not isinstance(base, nodes.ClassDef) for base in inferred_bases):
# All the bases needs to be Classes
return None
raise InferenceError(
"type.__new__() requires classes for bases", context=context
)

# Verify the attributes.
try:
Expand All @@ -611,7 +623,9 @@ def _infer_type_new_call(
raise InferenceError(context=context) from e
if not isinstance(attrs, nodes.Dict):
# Needs to be a dictionary.
return None
raise InferenceError(
"type.__new__() requires a dict for attrs", context=context
)
cls_locals: dict[str, list[InferenceResult]] = collections.defaultdict(list)
for key, value in attrs.items:
try:
Expand Down Expand Up @@ -664,9 +678,13 @@ def infer_call_result(
and self.bound.name == "type"
and self.name == "__new__"
and isinstance(caller, nodes.Call)
and len(caller.args) == 4
):
# Check if we have a ``type.__new__(mcs, name, bases, attrs)`` call.
if len(caller.args) != 4:
raise InferenceError(
f"type.__new__() requires 4 arguments, got {len(caller.args)}",
context=context,
)
new_cls = self._infer_type_new_call(caller, context)
if new_cls:
return iter((new_cls,))
Expand Down