diff --git a/src/griffe_pydantic/_internal/common.py b/src/griffe_pydantic/_internal/common.py index 4e66ad1..39a4135 100644 --- a/src/griffe_pydantic/_internal/common.py +++ b/src/griffe_pydantic/_internal/common.py @@ -1,17 +1,32 @@ from __future__ import annotations +import importlib import json +import sys from functools import partial from typing import TYPE_CHECKING +from griffe import get_logger + if TYPE_CHECKING: from collections.abc import Sequence from griffe import Attribute, Class, Function from pydantic import BaseModel +_DEFAULT_BASES = ( + "pydantic.BaseModel", + "pydantic.main.BaseModel", + "pydantic_settings.BaseSettings", + "pydantic_settings.main.BaseSettings", + "sqlmodel.SQLModel", + "sqlmodel.main.SQLModel", +) + + _self_namespace = "griffe_pydantic" _mkdocstrings_namespace = "mkdocstrings" +_logger = get_logger(__name__) _field_constraints = { "gt", @@ -77,3 +92,30 @@ def _process_function(func: Function, cls: Class, fields: Sequence[str]) -> None for target in targets: target.extra[_self_namespace].setdefault("validators", []) target.extra[_self_namespace]["validators"].append(func) + + +def _import_from_name(name: str) -> type[BaseModel]: + """Given a fully-qualified `package.module.Class` name, return the imported class.""" + module_name, _, class_name = name.rpartition(".") + module = sys.modules.get(module_name, importlib.import_module(module_name)) + try: + return getattr(module, class_name) + except AttributeError as e: + raise AttributeError(f"No class {class_name} in module {module}") from e + + +def _import_bases(names: tuple[str, ...]) -> tuple[type[BaseModel], ...]: + """Import a set of bases from fully-qualified `package.module.Class` names. + + Does not raise for import errors, + since we don't expect all possible bases to be present. + """ + bases = [] + for name in names: + try: + bases.append(_import_from_name(name)) + except ImportError: + # fine, we expect some of the defaults to fail, we only care if we have none + _logger.debug("Could not import %s", name) + + return tuple(bases) diff --git a/src/griffe_pydantic/_internal/extension.py b/src/griffe_pydantic/_internal/extension.py index d21ca05..b9e22ce 100644 --- a/src/griffe_pydantic/_internal/extension.py +++ b/src/griffe_pydantic/_internal/extension.py @@ -10,7 +10,7 @@ get_logger, ) -from griffe_pydantic._internal import dynamic, static +from griffe_pydantic._internal import common, dynamic, static if TYPE_CHECKING: from griffe import ObjectNode @@ -22,14 +22,26 @@ class PydanticExtension(Extension): """Griffe extension for Pydantic.""" - def __init__(self, *, schema: bool = False) -> None: + def __init__( + self, + *, + schema: bool = False, + bases: tuple[str, ...] | list[str] = common._DEFAULT_BASES, + include_bases: tuple[str, ...] | list[str] | None = None, + ) -> None: """Initialize the extension. Parameters: schema: Whether to compute and store the JSON schema of models. + bases: Tuple of complete `package.module.Class` references to base classes that should be considered + pydantic models. Declaring this *replaces* the default bases. + include_bases: *Additional* base classes to consider as pydantic models, including the defaults. """ super().__init__() self._schema = schema + self._bases = tuple(bases) + if include_bases: + self._bases += tuple(include_bases) self._processed: set[str] = set() self._recorded: list[tuple[ObjectNode, Class]] = [] @@ -38,7 +50,7 @@ def on_package_loaded(self, *, pkg: Module, **kwargs: Any) -> None: # noqa: ARG for node, cls in self._recorded: self._processed.add(cls.canonical_path) dynamic._process_class(node.obj, cls, processed=self._processed, schema=self._schema) - static._process_module(pkg, processed=self._processed, schema=self._schema) + static._process_module(pkg, processed=self._processed, schema=self._schema, bases=self._bases) def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: Any) -> None: # noqa: ARG002 """Detect and prepare Pydantic models.""" @@ -46,11 +58,13 @@ def on_class_instance(self, *, node: ast.AST | ObjectNode, cls: Class, **kwargs: if isinstance(node, ast.AST): return - try: - import pydantic - except ImportError: - _logger.warning("could not import pydantic - models will not be detected") + bases = common._import_bases(self._bases) + if not bases: + _logger.warning( + "could not import any expected model base - models will not be detected. \nexpected: %s", + self._bases, + ) return - if issubclass(node.obj, pydantic.BaseModel): + if issubclass(node.obj, bases): self._recorded.append((node, cls)) diff --git a/src/griffe_pydantic/_internal/static.py b/src/griffe_pydantic/_internal/static.py index f0caee1..0d36db9 100644 --- a/src/griffe_pydantic/_internal/static.py +++ b/src/griffe_pydantic/_internal/static.py @@ -29,7 +29,7 @@ _logger = get_logger(__name__) -def _inherits_pydantic(cls: Class) -> bool: +def _inherits_pydantic(cls: Class, bases: tuple[str, ...] = common._DEFAULT_BASES) -> bool: """Tell whether a class inherits from a Pydantic model. Parameters: @@ -41,10 +41,10 @@ def _inherits_pydantic(cls: Class) -> bool: for base in cls.bases: if isinstance(base, (ExprName, Expr)): base = base.canonical_path # noqa: PLW2901 - if base in {"pydantic.BaseModel", "pydantic.main.BaseModel"}: + if base in bases: return True - return any(_inherits_pydantic(parent_class) for parent_class in cls.mro()) + return any(_inherits_pydantic(parent_class, bases) for parent_class in cls.mro()) def _pydantic_validator(func: Function) -> ExprCall | None: @@ -141,12 +141,18 @@ def _process_function(func: Function, cls: Class, *, processed: set[str]) -> Non common._process_function(func, cls, fields) -def _process_class(cls: Class, *, processed: set[str], schema: bool = False) -> None: +def _process_class( + cls: Class, + *, + processed: set[str], + schema: bool = False, + bases: tuple[str, ...] = common._DEFAULT_BASES, +) -> None: """Finalize the Pydantic model data.""" if cls.canonical_path in processed: return - if not _inherits_pydantic(cls): + if not _inherits_pydantic(cls, bases): return processed.add(cls.canonical_path) @@ -182,6 +188,7 @@ def _process_module( *, processed: set[str], schema: bool = False, + bases: tuple[str, ...] = common._DEFAULT_BASES, ) -> None: """Handle Pydantic models in a module.""" if mod.canonical_path in processed: @@ -191,7 +198,7 @@ def _process_module( for cls in mod.classes.values(): # Don't process aliases, real classes will be processed at some point anyway. if not cls.is_alias: - _process_class(cls, processed=processed, schema=schema) + _process_class(cls, processed=processed, schema=schema, bases=bases) for submodule in mod.modules.values(): - _process_module(submodule, processed=processed, schema=schema) + _process_module(submodule, processed=processed, schema=schema, bases=bases) diff --git a/tests/test_extension.py b/tests/test_extension.py index de582b3..76ac39e 100644 --- a/tests/test_extension.py +++ b/tests/test_extension.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from textwrap import dedent from typing import TYPE_CHECKING import pytest @@ -232,3 +233,61 @@ class B(BaseModel, A): extensions=Extensions(PydanticExtension(schema=False)), ) as package: assert "pydantic-field" in package["B.a"].labels + + +@pytest.mark.parametrize("analysis", ["static", "dynamic"]) +@pytest.mark.parametrize("base_mode", ["bases", "include_bases"]) +def test_detect_custom_bases(analysis: str, base_mode: str) -> None: + """We can detect pydantic models with non-standard bases as specified by config.""" + package_name = "package" + module_name = "builtins" + class_name = "object" + code = dedent(f""" + from {module_name} import {class_name} + + class ExampleParentModel({class_name}): + pass + + class ExampleModel(ExampleParentModel): + pass + """) + + fake_module = dedent(f""" + class {class_name}: + pass + """) + + extension_kwargs = {base_mode: [".".join([module_name, class_name])]} + + loader = { + "static": temporary_visited_package, + "dynamic": temporary_inspected_package, + }[analysis] + with loader( + package_name, + modules={"__init__.py": code, module_name + ".py": fake_module}, + extensions=Extensions(PydanticExtension(**extension_kwargs)), # type: ignore[arg-type] + ) as package: + assert "ExampleParentModel" in package.classes + assert "ExampleModel" in package.classes + assert package.classes["ExampleParentModel"].labels == {"pydantic-model"} + assert package.classes["ExampleModel"].labels == {"pydantic-model"} + + +@pytest.mark.parametrize("analysis", ["static", "dynamic"]) +def test_replace_default_bases(analysis: str) -> None: + """When we replace the bases, pydantic models should no longer be annotated.""" + loader = { + "static": temporary_visited_package, + "dynamic": temporary_inspected_package, + }[analysis] + + with loader( + "package", + modules={"__init__.py": code}, + extensions=Extensions(PydanticExtension(bases=("fake.fakepackage.NoExisty",))), + ) as package: + assert "ExampleParentModel" in package.classes + assert "ExampleModel" in package.classes + assert not package.classes["ExampleParentModel"].labels + assert not package.classes["ExampleModel"].labels