Skip to content

Commit 00bbd68

Browse files
committed
Refactored Pydantic model inheritance detection
1 parent b522da8 commit 00bbd68

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

mypy/stubgen.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -861,13 +861,8 @@ def visit_class_def(self, o: ClassDef) -> None:
861861
if self.analyzed and (spec := find_dataclass_transform_spec(o)):
862862
self.processing_dataclass = True
863863
self.dataclass_field_specifier = spec.field_specifiers
864-
for base_type_expr in o.base_type_exprs:
865-
if (
866-
isinstance(base_type_expr, (NameExpr, MemberExpr))
867-
and self.get_fullname(base_type_expr) == "pydantic.BaseModel"
868-
):
869-
self.processing_pydantic_model = True
870-
break
864+
if self._inherits_from_pydantic_basemodel(o):
865+
self.processing_pydantic_model = True
871866
super().visit_class_def(o)
872867
self.dedent()
873868
self._vars.pop()
@@ -887,6 +882,20 @@ def visit_class_def(self, o: ClassDef) -> None:
887882
self.processing_enum = False
888883
self.processing_pydantic_model = False
889884

885+
def _inherits_from_pydantic_basemodel(self, class_def: ClassDef) -> bool:
886+
"""Check if a class directly or indirectly inherits from pydantic.BaseModel"""
887+
for base_type_expr in class_def.base_type_exprs:
888+
if (
889+
isinstance(base_type_expr, (NameExpr, MemberExpr))
890+
and self.get_fullname(base_type_expr) == "pydantic.BaseModel"
891+
):
892+
return True
893+
if self.analyzed and class_def.info:
894+
for base_class in class_def.info.mro:
895+
if base_class.fullname == "pydantic.BaseModel":
896+
return True
897+
return False
898+
890899
def get_base_types(self, cdef: ClassDef) -> list[str]:
891900
"""Get list of base classes for a class."""
892901
base_types: list[str] = []

test-data/unit/stubgen.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4801,7 +4801,7 @@ class User(BaseModel):
48014801
age: int
48024802
address: Address | None = ...
48034803

4804-
[case testPydanticBaseModelInheritance]
4804+
[case testPydanticBaseModelInheritance_semanal]
48054805
from pydantic import BaseModel
48064806

48074807
class BaseUser(BaseModel):

0 commit comments

Comments
 (0)