diff --git a/codeflash/code_utils/static_analysis.py b/codeflash/code_utils/static_analysis.py index dbddb59f5..ed8e5f7db 100644 --- a/codeflash/code_utils/static_analysis.py +++ b/codeflash/code_utils/static_analysis.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, ConfigDict, field_validator +from codeflash.models.models import FunctionParent + if TYPE_CHECKING: from codeflash.models.models import FunctionParent @@ -139,14 +141,20 @@ def get_first_top_level_function_or_method_ast( def function_kind(node: ast.FunctionDef | ast.AsyncFunctionDef, parents: list[FunctionParent]) -> FunctionKind | None: - if not parents or parents[0].type in ["FunctionDef", "AsyncFunctionDef"]: + # Fast exit if no parents or parent is function + if not parents: + return FunctionKind.FUNCTION + parent_type = parents[0].type + if parent_type in _FUNCTION_OR_ASYNC: return FunctionKind.FUNCTION - if parents[0].type == "ClassDef": + if parent_type == "ClassDef": + # Use a set for quick id matching, short-circuit as soon as possible for decorator in node.decorator_list: if isinstance(decorator, ast.Name): - if decorator.id == "classmethod": + did = decorator.id + if did == "classmethod": return FunctionKind.CLASS_METHOD - if decorator.id == "staticmethod": + if did == "staticmethod": return FunctionKind.STATIC_METHOD return FunctionKind.INSTANCE_METHOD return None @@ -159,3 +167,6 @@ def has_typed_parameters(node: ast.FunctionDef | ast.AsyncFunctionDef, parents: if kind in [FunctionKind.CLASS_METHOD, FunctionKind.INSTANCE_METHOD]: return all(arg.annotation for arg in node.args.args[1:]) return False + + +_FUNCTION_OR_ASYNC = {"FunctionDef", "AsyncFunctionDef"}