diff --git a/.codegen/.gitignore b/.codegen/.gitignore index a384b9257..77d89d205 100644 --- a/.codegen/.gitignore +++ b/.codegen/.gitignore @@ -8,12 +8,8 @@ jupyter/ codegen-system-prompt.txt # Python cache files -__pycache__/ +**/__pycache__/ *.py[cod] *$py.class *.txt *.pyc - -# Keep codemods -!codemods/ -!codemods/** diff --git a/.codegen/codemods/no_link_backticks/no_link_backticks.py b/.codegen/codemods/no_link_backticks/no_link_backticks.py index b74509cc6..e8cda5323 100644 --- a/.codegen/codemods/no_link_backticks/no_link_backticks.py +++ b/.codegen/codemods/no_link_backticks/no_link_backticks.py @@ -2,7 +2,7 @@ from codegen import Codebase -@codegen.function("no-link-backticks") +@codegen.function(name="no-link-backticks", subdirectories=["test/unit"]) def run(codebase: Codebase): import re @@ -12,6 +12,7 @@ def run(codebase: Codebase): # Iterate over all .mdx files in the codebase for file in codebase.files(extensions=["mdx"]): if file.extension == ".mdx": + print(f"Processing {file.path}") new_content = file.content # Find all markdown links with backticks in link text diff --git a/src/codegen/cli/commands/list/main.py b/src/codegen/cli/commands/list/main.py index 046c1a17e..944828e3c 100644 --- a/src/codegen/cli/commands/list/main.py +++ b/src/codegen/cli/commands/list/main.py @@ -17,10 +17,16 @@ def list_command(): table.add_column("Name", style="cyan") table.add_column("Type", style="magenta") table.add_column("Path", style="dim") + table.add_column("Subdirectories", style="dim") for func in functions: func_type = "Webhook" if func.lint_mode else "Function" - table.add_row(func.name, func_type, str(func.filepath.relative_to(Path.cwd())) if func.filepath else "") + table.add_row( + func.name, + func_type, + str(func.filepath.relative_to(Path.cwd())) if func.filepath else "", + ", ".join(func.subdirectories) if func.subdirectories else "", + ) rich.print(table) rich.print("\nRun a function with:") diff --git a/src/codegen/cli/commands/run/main.py b/src/codegen/cli/commands/run/main.py index 1c393abfe..d34d93784 100644 --- a/src/codegen/cli/commands/run/main.py +++ b/src/codegen/cli/commands/run/main.py @@ -13,14 +13,12 @@ @click.command(name="run") @requires_init @click.argument("label", required=True) -@click.option("--path", type=str, help="Path to build the codebase from. Defaults to the repo root.") @click.option("--web", is_flag=True, help="Run the function on the web service instead of locally") @click.option("--diff-preview", type=int, help="Show a preview of the first N lines of the diff") @click.option("--arguments", type=str, help="Arguments as a json string to pass as the function's 'arguments' parameter") def run_command( session: CodegenSession, label: str, - path: str | None = None, web: bool = False, diff_preview: int | None = None, arguments: str | None = None, @@ -59,4 +57,4 @@ def run_command( else: from codegen.cli.commands.run.run_local import run_local - run_local(session, codemod, diff_preview=diff_preview, path=path) + run_local(session, codemod, diff_preview=diff_preview) diff --git a/src/codegen/cli/commands/run/run_local.py b/src/codegen/cli/commands/run/run_local.py index 4f030b4fa..8e2afdfea 100644 --- a/src/codegen/cli/commands/run/run_local.py +++ b/src/codegen/cli/commands/run/run_local.py @@ -6,10 +6,16 @@ from codegen.cli.auth.session import CodegenSession from codegen.cli.utils.function_finder import DecoratedFunction +from codegen.git.repo_operator.repo_operator import RepoOperator +from codegen.git.schemas.repo_config import RepoConfig +from codegen.sdk.codebase.config import ProjectConfig from codegen.sdk.core.codebase import Codebase -def parse_codebase(repo_root: Path) -> Codebase: +def parse_codebase( + repo_path: Path, + subdirectories: list[str] | None = None, +) -> Codebase: """Parse the codebase at the given root. Args: @@ -18,7 +24,14 @@ def parse_codebase(repo_root: Path) -> Codebase: Returns: Parsed Codebase object """ - codebase = Codebase(repo_root) + codebase = Codebase( + projects=[ + ProjectConfig( + repo_operator=RepoOperator(repo_config=RepoConfig.from_repo_path(repo_path=repo_path)), + subdirectories=subdirectories, + ) + ] + ) return codebase @@ -26,7 +39,6 @@ def run_local( session: CodegenSession, function: DecoratedFunction, diff_preview: int | None = None, - path: Path | None = None, ) -> None: """Run a function locally against the codebase. @@ -36,10 +48,8 @@ def run_local( diff_preview: Number of lines of diff to preview (None for all) """ # Parse codebase and run - codebase_path = f"{session.repo_path}/{path}" if path else session.repo_path - - with Status(f"[bold]Parsing codebase at {codebase_path} ...", spinner="dots") as status: - codebase = parse_codebase(codebase_path) + with Status(f"[bold]Parsing codebase at {session.repo_path} with subdirectories {function.subdirectories or 'ALL'} ...", spinner="dots") as status: + codebase = parse_codebase(repo_path=session.repo_path, subdirectories=function.subdirectories) status.update("[bold green]✓ Parsed codebase") status.update("[bold]Running codemod...") diff --git a/src/codegen/cli/sdk/decorator.py b/src/codegen/cli/sdk/decorator.py index e4e93e956..3a9d7938b 100644 --- a/src/codegen/cli/sdk/decorator.py +++ b/src/codegen/cli/sdk/decorator.py @@ -15,11 +15,13 @@ def __init__( self, name: str, *, + subdirectories: list[str] | None = None, webhook_config: dict | None = None, lint_mode: bool = False, lint_user_whitelist: Sequence[str] | None = None, ): self.name = name + self.subdirectories = subdirectories self.func: Callable | None = None self.params_type = None self.webhook_config = webhook_config @@ -42,7 +44,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return wrapper -def function(name: str) -> DecoratedFunction: +def function(name: str, subdirectories: list[str] | None = None) -> DecoratedFunction: """Decorator for codegen functions. Args: @@ -54,7 +56,7 @@ def run(codebase): pass """ - return DecoratedFunction(name) + return DecoratedFunction(name=name, subdirectories=subdirectories) def webhook( diff --git a/src/codegen/cli/utils/function_finder.py b/src/codegen/cli/utils/function_finder.py index 0d77a71fd..325dece3f 100644 --- a/src/codegen/cli/utils/function_finder.py +++ b/src/codegen/cli/utils/function_finder.py @@ -14,6 +14,7 @@ class DecoratedFunction: source: str lint_mode: bool lint_user_whitelist: list[str] + subdirectories: list[str] | None = None filepath: Path | None = None parameters: list[tuple[str, str | None]] = dataclasses.field(default_factory=list) arguments_type_schema: dict | None = None @@ -83,6 +84,20 @@ class CodegenFunctionVisitor(ast.NodeVisitor): def __init__(self): self.functions: list[DecoratedFunction] = [] + def get_function_name(self, node: ast.Call) -> str: + keywords = {k.arg: k.value for k in node.keywords} + if "name" in keywords: + return ast.literal_eval(keywords["name"]) + return ast.literal_eval(node.args[0]) + + def get_subdirectories(self, node: ast.Call) -> list[str] | None: + keywords = {k.arg: k.value for k in node.keywords} + if "subdirectories" in keywords: + return ast.literal_eval(keywords["subdirectories"]) + if len(node.args) > 1: + return ast.literal_eval(node.args[1]) + return None + def get_function_body(self, node: ast.FunctionDef) -> str: """Extract and unindent the function body.""" # Get the start and end positions of the function body @@ -178,7 +193,7 @@ def visit_FunctionDef(self, node): for decorator in node.decorator_list: if ( isinstance(decorator, ast.Call) - and len(decorator.args) >= 1 + and (len(decorator.args) > 0 or len(decorator.keywords) > 0) and ( # Check if it's a direct codegen.X call (isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name) and decorator.func.value.id == "codegen") @@ -188,7 +203,8 @@ def visit_FunctionDef(self, node): ) ): # Get the function name from the decorator argument - func_name = ast.literal_eval(decorator.args[0]) + func_name = self.get_function_name(decorator) + subdirectories = self.get_subdirectories(decorator) # Get additional metadata for webhook lint_mode = decorator.func.attr == "webhook" @@ -201,7 +217,16 @@ def visit_FunctionDef(self, node): # Get just the function body, unindented body_source = self.get_function_body(node) parameters = self.get_function_parameters(node) - self.functions.append(DecoratedFunction(name=func_name, source=body_source, lint_mode=lint_mode, lint_user_whitelist=lint_user_whitelist, parameters=parameters)) + self.functions.append( + DecoratedFunction( + name=func_name, + subdirectories=subdirectories, + source=body_source, + lint_mode=lint_mode, + lint_user_whitelist=lint_user_whitelist, + parameters=parameters, + ) + ) def _has_codegen_root(self, node): """Recursively check if an AST node chain starts with codegen."""