|
| 1 | +"""Mypy plugin for CrewAI decorator type checking. |
| 2 | +
|
| 3 | +This plugin informs mypy about attributes injected by the @CrewBase decorator. |
| 4 | +""" |
| 5 | + |
| 6 | +from collections.abc import Callable |
| 7 | + |
| 8 | +from mypy.nodes import MDEF, SymbolTableNode, Var |
| 9 | +from mypy.plugin import ClassDefContext, Plugin |
| 10 | +from mypy.types import AnyType, TypeOfAny |
| 11 | + |
| 12 | + |
| 13 | +class CrewAIPlugin(Plugin): |
| 14 | + """Mypy plugin that handles @CrewBase decorator attribute injection.""" |
| 15 | + |
| 16 | + def get_class_decorator_hook( |
| 17 | + self, fullname: str |
| 18 | + ) -> Callable[[ClassDefContext], None] | None: |
| 19 | + """Return hook for class decorators. |
| 20 | +
|
| 21 | + Args: |
| 22 | + fullname: Fully qualified name of the decorator. |
| 23 | +
|
| 24 | + Returns: |
| 25 | + Hook function if this is a CrewBase decorator, None otherwise. |
| 26 | + """ |
| 27 | + if fullname in ("crewai.project.CrewBase", "crewai.project.crew_base.CrewBase"): |
| 28 | + return self._crew_base_hook |
| 29 | + return None |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def _crew_base_hook(ctx: ClassDefContext) -> None: |
| 33 | + """Add injected attributes to @CrewBase decorated classes. |
| 34 | +
|
| 35 | + Args: |
| 36 | + ctx: Context for the class being decorated. |
| 37 | + """ |
| 38 | + any_type = AnyType(TypeOfAny.explicit) |
| 39 | + str_type = ctx.api.named_type("builtins.str") |
| 40 | + dict_type = ctx.api.named_type("builtins.dict", [str_type, any_type]) |
| 41 | + agents_config_var = Var("agents_config", dict_type) |
| 42 | + agents_config_var.info = ctx.cls.info |
| 43 | + agents_config_var._fullname = f"{ctx.cls.info.fullname}.agents_config" |
| 44 | + ctx.cls.info.names["agents_config"] = SymbolTableNode(MDEF, agents_config_var) |
| 45 | + tasks_config_var = Var("tasks_config", dict_type) |
| 46 | + tasks_config_var.info = ctx.cls.info |
| 47 | + tasks_config_var._fullname = f"{ctx.cls.info.fullname}.tasks_config" |
| 48 | + ctx.cls.info.names["tasks_config"] = SymbolTableNode(MDEF, tasks_config_var) |
| 49 | + |
| 50 | + |
| 51 | +def plugin(_: str) -> type[Plugin]: |
| 52 | + """Entry point for mypy plugin. |
| 53 | +
|
| 54 | + Args: |
| 55 | + _: Mypy version string. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + Plugin class. |
| 59 | + """ |
| 60 | + return CrewAIPlugin |
0 commit comments