diff --git a/ruff.toml b/ruff.toml index 64eb33912..338b939e4 100644 --- a/ruff.toml +++ b/ruff.toml @@ -6,7 +6,20 @@ exclude = [ ] unsafe-fixes = true [lint] -select = ["F", "E", "W", "I", "UP", "D", "RUF", "DTZ", "TC", "EM", "TRY400"] +select = [ + "F", + "E", + "W", + "I", + "UP", + "D", + "YTT", + "RUF", + "DTZ", + "TC", + "EM", + "TRY400", +] ignore = [ "D100", "D101", @@ -216,3 +229,5 @@ extend-generics = [ ] [lint.isort] known-first-party = ["codegen"] +[format] +docstring-code-format = true diff --git a/scripts/profiling/profile.py b/scripts/profiling/profile.py index 4b0264a3e..c867fcf5c 100644 --- a/scripts/profiling/profile.py +++ b/scripts/profiling/profile.py @@ -5,7 +5,7 @@ import typer -def profile(repo: str, memory: bool = False): +def profile(repo: str, memory: bool = False) -> None: type = "mem" if memory else "cpu" base = f".profiles/{type}/{repo}" os.makedirs(base, exist_ok=True) @@ -37,7 +37,7 @@ def profile(repo: str, memory: bool = False): test.kill() -def main(): +def main() -> None: typer.run(profile) diff --git a/src/codegen/cli/api/client.py b/src/codegen/cli/api/client.py index 442928231..d31982b9c 100644 --- a/src/codegen/cli/api/client.py +++ b/src/codegen/cli/api/client.py @@ -55,7 +55,7 @@ class RestAPI: auth_token: str | None = None - def __init__(self, auth_token: str): + def __init__(self, auth_token: str) -> None: self.auth_token = auth_token def _get_headers(self) -> dict[str, str]: diff --git a/src/codegen/cli/api/modal.py b/src/codegen/cli/api/modal.py index 16b4e3d62..712dafe70 100644 --- a/src/codegen/cli/api/modal.py +++ b/src/codegen/cli/api/modal.py @@ -2,7 +2,7 @@ from codegen.cli.env.global_env import global_env -def get_modal_workspace(): +def get_modal_workspace() -> str: match global_env.ENV: case Environment.PRODUCTION: return "codegen-sh" diff --git a/src/codegen/cli/auth/session.py b/src/codegen/cli/auth/session.py index caba48958..dd4aa460f 100644 --- a/src/codegen/cli/auth/session.py +++ b/src/codegen/cli/auth/session.py @@ -46,7 +46,7 @@ class CodegenSession: _identity: Identity | None = None _profile: UserProfile | None = None - def __init__(self, token: str | None = None): + def __init__(self, token: str | None = None) -> None: self.token = token or get_current_token() @property diff --git a/src/codegen/cli/auth/token_manager.py b/src/codegen/cli/auth/token_manager.py index 04d61c823..c55d20528 100644 --- a/src/codegen/cli/auth/token_manager.py +++ b/src/codegen/cli/auth/token_manager.py @@ -9,12 +9,12 @@ class TokenManager: # Simple token manager to store and retrieve tokens. # This manager checks if the token is expired before retrieval. # TODO: add support for refreshing token and re authorization via supabase oauth - def __init__(self): + def __init__(self) -> None: self.config_dir = CONFIG_DIR self.token_file = AUTH_FILE self._ensure_config_dir() - def _ensure_config_dir(self): + def _ensure_config_dir(self) -> None: """Create config directory if it doesn't exist.""" if not os.path.exists(self.config_dir): Path(self.config_dir).mkdir(parents=True, exist_ok=True) diff --git a/src/codegen/cli/cli.py b/src/codegen/cli/cli.py index 4cd767bd8..b7189670a 100644 --- a/src/codegen/cli/cli.py +++ b/src/codegen/cli/cli.py @@ -21,7 +21,7 @@ @click.group() @click.version_option(prog_name="codegen", message="%(version)s") -def main(): +def main() -> None: """Codegen CLI - Transform your code with AI.""" diff --git a/src/codegen/cli/commands/create/main.py b/src/codegen/cli/commands/create/main.py index 390276a2d..8763a23d8 100644 --- a/src/codegen/cli/commands/create/main.py +++ b/src/codegen/cli/commands/create/main.py @@ -70,7 +70,7 @@ def make_relative(path: Path) -> str: @click.argument("path", type=click.Path(path_type=Path), default=Path.cwd()) @click.option("--description", "-d", default=None, help="Description of what this codemod does.") @click.option("--overwrite", is_flag=True, help="Overwrites function if it already exists.") -def create_command(session: CodegenSession, name: str, path: Path, description: str | None = None, overwrite: bool = False): +def create_command(session: CodegenSession, name: str, path: Path, description: str | None = None, overwrite: bool = False) -> None: """Create a new codegen function. NAME is the name/label for the function diff --git a/src/codegen/cli/commands/deploy/main.py b/src/codegen/cli/commands/deploy/main.py index 8963a3385..66bccc4ff 100644 --- a/src/codegen/cli/commands/deploy/main.py +++ b/src/codegen/cli/commands/deploy/main.py @@ -47,7 +47,7 @@ def deploy_functions(session: CodegenSession, functions: list[DecoratedFunction] @click.argument("name", required=False) @click.option("-d", "--directory", type=click.Path(exists=True, path_type=Path), help="Directory to search for functions") @click.option("-m", "--message", help="Optional message to include with the deploy") -def deploy_command(session: CodegenSession, name: str | None = None, directory: Path | None = None, message: str | None = None): +def deploy_command(session: CodegenSession, name: str | None = None, directory: Path | None = None, message: str | None = None) -> None: """Deploy codegen functions. If NAME is provided, deploys a specific function by that name. diff --git a/src/codegen/cli/commands/expert/main.py b/src/codegen/cli/commands/expert/main.py index d8e55d9d2..0fc30198e 100644 --- a/src/codegen/cli/commands/expert/main.py +++ b/src/codegen/cli/commands/expert/main.py @@ -13,7 +13,7 @@ @click.option("--query", "-q", help="The question to ask the expert.") @requires_auth @requires_init -def expert_command(session: CodegenSession, query: str): +def expert_command(session: CodegenSession, query: str) -> None: """Asks a codegen expert a question.""" status = Status("Asking expert...", spinner="dots", spinner_style="purple") status.start() diff --git a/src/codegen/cli/commands/init/main.py b/src/codegen/cli/commands/init/main.py index 06cedf762..5c22806b6 100644 --- a/src/codegen/cli/commands/init/main.py +++ b/src/codegen/cli/commands/init/main.py @@ -17,7 +17,7 @@ @click.option("--repo-name", type=str, help="The name of the repository") @click.option("--organization-name", type=str, help="The name of the organization") @click.option("--fetch-docs", is_flag=True, help="Fetch docs and examples (requires auth)") -def init_command(repo_name: str | None = None, organization_name: str | None = None, fetch_docs: bool = False): +def init_command(repo_name: str | None = None, organization_name: str | None = None, fetch_docs: bool = False) -> None: """Initialize or update the Codegen folder.""" # Print a message if not in a git repo try: diff --git a/src/codegen/cli/commands/list/main.py b/src/codegen/cli/commands/list/main.py index 046c1a17e..1257ac261 100644 --- a/src/codegen/cli/commands/list/main.py +++ b/src/codegen/cli/commands/list/main.py @@ -9,7 +9,7 @@ @click.command(name="list") -def list_command(): +def list_command() -> None: """List available codegen functions.""" functions = CodemodManager.get_decorated() if functions: diff --git a/src/codegen/cli/commands/login/main.py b/src/codegen/cli/commands/login/main.py index 67c7d46db..41615bc5e 100644 --- a/src/codegen/cli/commands/login/main.py +++ b/src/codegen/cli/commands/login/main.py @@ -8,7 +8,7 @@ @click.command(name="login") @click.option("--token", required=False, help="API token for authentication") -def login_command(token: str): +def login_command(token: str) -> None: """Store authentication token.""" # Check if already authenticated token_manager = TokenManager() diff --git a/src/codegen/cli/commands/logout/main.py b/src/codegen/cli/commands/logout/main.py index 93216bea0..a669cbdf2 100644 --- a/src/codegen/cli/commands/logout/main.py +++ b/src/codegen/cli/commands/logout/main.py @@ -5,7 +5,7 @@ @click.command(name="logout") -def logout_command(): +def logout_command() -> None: """Clear stored authentication token.""" token_manager = TokenManager() token_manager.clear_token() diff --git a/src/codegen/cli/commands/notebook/main.py b/src/codegen/cli/commands/notebook/main.py index d00447288..271146aa2 100644 --- a/src/codegen/cli/commands/notebook/main.py +++ b/src/codegen/cli/commands/notebook/main.py @@ -23,7 +23,7 @@ def create_jupyter_dir() -> Path: @click.option("--background", is_flag=True, help="Run Jupyter Lab in the background") @click.option("--demo", is_flag=True, help="Create a demo notebook with FastAPI example code") @requires_init -def notebook_command(session: CodegenSession, background: bool, demo: bool): +def notebook_command(session: CodegenSession, background: bool, demo: bool) -> None: """Launch Jupyter Lab with a pre-configured notebook for exploring your codebase.""" with create_spinner("Setting up Jupyter environment...") as status: venv = VenvManager() diff --git a/src/codegen/cli/commands/profile/main.py b/src/codegen/cli/commands/profile/main.py index 856b68b7a..cb4c8b782 100644 --- a/src/codegen/cli/commands/profile/main.py +++ b/src/codegen/cli/commands/profile/main.py @@ -11,7 +11,7 @@ @click.command(name="profile") @requires_auth @requires_init -def profile_command(session: CodegenSession): +def profile_command(session: CodegenSession) -> None: """Display information about the currently authenticated user.""" rich.print( Panel( diff --git a/src/codegen/cli/commands/run/main.py b/src/codegen/cli/commands/run/main.py index 13e9cf843..01c79bca2 100644 --- a/src/codegen/cli/commands/run/main.py +++ b/src/codegen/cli/commands/run/main.py @@ -22,7 +22,7 @@ def run_command( web: bool = False, diff_preview: int | None = None, arguments: str | None = None, -): +) -> None: """Run a codegen function by its label.""" # Ensure venv is initialized venv = VenvManager() diff --git a/src/codegen/cli/commands/run/render.py b/src/codegen/cli/commands/run/render.py index fd60ffa6e..207e96af8 100644 --- a/src/codegen/cli/commands/run/render.py +++ b/src/codegen/cli/commands/run/render.py @@ -6,7 +6,7 @@ from codegen.cli.api.schemas import RunCodemodOutput -def pretty_print_output(output: RunCodemodOutput): +def pretty_print_output(output: RunCodemodOutput) -> None: """Pretty print the codemod run output with panels.""" if output.web_link: rich.print("\n• [blue underline]" + output.web_link + "[/blue underline]\n") @@ -21,7 +21,7 @@ def pretty_print_output(output: RunCodemodOutput): pretty_print_diff(output.observation) -def pretty_print_logs(logs: str): +def pretty_print_logs(logs: str) -> None: """Pretty print logs in a panel.""" rich.print( Panel( @@ -35,7 +35,7 @@ def pretty_print_logs(logs: str): rich.print() # spacing -def pretty_print_error(error: str): +def pretty_print_error(error: str) -> None: """Pretty print error in a panel.""" rich.print( Panel( @@ -49,7 +49,7 @@ def pretty_print_error(error: str): rich.print() # spacing -def pretty_print_diff(diff: str): +def pretty_print_diff(diff: str) -> None: """Pretty print diff in a panel.""" rich.print( Panel( diff --git a/src/codegen/cli/commands/run/run_cloud.py b/src/codegen/cli/commands/run/run_cloud.py index e066b3967..adffa341b 100644 --- a/src/codegen/cli/commands/run/run_cloud.py +++ b/src/codegen/cli/commands/run/run_cloud.py @@ -13,7 +13,7 @@ from codegen.cli.utils.url import generate_webapp_url -def run_cloud(session: CodegenSession, function, apply_local: bool = False, diff_preview: int | None = None): +def run_cloud(session: CodegenSession, function, apply_local: bool = False, diff_preview: int | None = None) -> None: """Run a function on the cloud service. Args: diff --git a/src/codegen/cli/commands/run_on_pr/main.py b/src/codegen/cli/commands/run_on_pr/main.py index d3b28c471..34d6f8ddb 100644 --- a/src/codegen/cli/commands/run_on_pr/main.py +++ b/src/codegen/cli/commands/run_on_pr/main.py @@ -36,7 +36,7 @@ def run_on_pr(session: CodegenSession, codemod_name: str, pr_number: int) -> Non @requires_auth @click.argument("codemod_name", type=str) @click.argument("pr_number", type=int) -def run_on_pr_command(session: CodegenSession, codemod_name: str, pr_number: int): +def run_on_pr_command(session: CodegenSession, codemod_name: str, pr_number: int) -> None: """Test a webhook against a specific PR. CODEMOD_NAME is the name of the codemod to test diff --git a/src/codegen/cli/commands/style_debug/main.py b/src/codegen/cli/commands/style_debug/main.py index 642c16f86..1b99ef8f6 100644 --- a/src/codegen/cli/commands/style_debug/main.py +++ b/src/codegen/cli/commands/style_debug/main.py @@ -9,7 +9,7 @@ @click.command(name="style-debug") @click.option("--text", default="Loading...", help="Text to show in the spinner") -def style_debug_command(text: str): +def style_debug_command(text: str) -> None: """Debug command to visualize CLI styling (spinners, etc).""" try: with create_spinner(text) as status: diff --git a/src/codegen/cli/errors.py b/src/codegen/cli/errors.py index f2a9061cb..1c8f33f1c 100644 --- a/src/codegen/cli/errors.py +++ b/src/codegen/cli/errors.py @@ -36,7 +36,7 @@ class ServerError(CodegenError): pass -def format_error_message(error): +def format_error_message(error) -> str: """Format error message based on error type.""" if isinstance(error, AuthError): return "[red]Authentication Error:[/red] Please run 'codegen login' first." diff --git a/src/codegen/cli/git/patch.py b/src/codegen/cli/git/patch.py index 1dbb991e7..99efea595 100644 --- a/src/codegen/cli/git/patch.py +++ b/src/codegen/cli/git/patch.py @@ -4,7 +4,7 @@ from rich.status import Status -def apply_patch(git_repo: Repository, patch: str): +def apply_patch(git_repo: Repository, patch: str) -> None: """Apply a git patch to the repository. Args: diff --git a/src/codegen/cli/rich/pretty_print.py b/src/codegen/cli/rich/pretty_print.py index fd60ffa6e..207e96af8 100644 --- a/src/codegen/cli/rich/pretty_print.py +++ b/src/codegen/cli/rich/pretty_print.py @@ -6,7 +6,7 @@ from codegen.cli.api.schemas import RunCodemodOutput -def pretty_print_output(output: RunCodemodOutput): +def pretty_print_output(output: RunCodemodOutput) -> None: """Pretty print the codemod run output with panels.""" if output.web_link: rich.print("\n• [blue underline]" + output.web_link + "[/blue underline]\n") @@ -21,7 +21,7 @@ def pretty_print_output(output: RunCodemodOutput): pretty_print_diff(output.observation) -def pretty_print_logs(logs: str): +def pretty_print_logs(logs: str) -> None: """Pretty print logs in a panel.""" rich.print( Panel( @@ -35,7 +35,7 @@ def pretty_print_logs(logs: str): rich.print() # spacing -def pretty_print_error(error: str): +def pretty_print_error(error: str) -> None: """Pretty print error in a panel.""" rich.print( Panel( @@ -49,7 +49,7 @@ def pretty_print_error(error: str): rich.print() # spacing -def pretty_print_diff(diff: str): +def pretty_print_diff(diff: str) -> None: """Pretty print diff in a panel.""" rich.print( Panel( diff --git a/src/codegen/cli/sdk/decorator.py b/src/codegen/cli/sdk/decorator.py index e4e93e956..30386caaa 100644 --- a/src/codegen/cli/sdk/decorator.py +++ b/src/codegen/cli/sdk/decorator.py @@ -18,7 +18,7 @@ def __init__( webhook_config: dict | None = None, lint_mode: bool = False, lint_user_whitelist: Sequence[str] | None = None, - ): + ) -> None: self.name = name self.func: Callable | None = None self.params_type = None diff --git a/src/codegen/cli/sdk/pull_request.py b/src/codegen/cli/sdk/pull_request.py index 7caca0ada..79c4006cc 100644 --- a/src/codegen/cli/sdk/pull_request.py +++ b/src/codegen/cli/sdk/pull_request.py @@ -14,7 +14,7 @@ class CodegenPullRequest: codegen_pr_id: int patch_set: PatchSet | None = None - def __init__(self, url: str, number: int, title: str, github_pr_number: int, codegen_pr_id: int, patch_set: PatchSet | None = None): + def __init__(self, url: str, number: int, title: str, github_pr_number: int, codegen_pr_id: int, patch_set: PatchSet | None = None) -> None: self.url = url self.number = number self.title = title @@ -47,5 +47,5 @@ def lookup(cls, number: int) -> "CodegenPullRequest": patch_set=None, # Can be loaded on demand if needed ) - def __str__(self): + def __str__(self) -> str: return f"CodegenPullRequest(url={self.url}, number={self.number}, title={self.title}, github_pr_number={self.github_pr_number}, codegen_pr_id={self.codegen_pr_id})" diff --git a/src/codegen/cli/utils/count_functions.py b/src/codegen/cli/utils/count_functions.py index 014c28a60..39e99f086 100644 --- a/src/codegen/cli/utils/count_functions.py +++ b/src/codegen/cli/utils/count_functions.py @@ -20,7 +20,7 @@ class CountFunctionsArgs(BaseModel): @codegen.cli.sdk.decorator.function("count-functions") -def run(codebase, pr_options, arguments: CountFunctionsArgs): +def run(codebase, pr_options, arguments: CountFunctionsArgs) -> None: # Count Functions in Codebase # Initialize a total function counter diff --git a/src/codegen/cli/utils/function_finder.py b/src/codegen/cli/utils/function_finder.py index 0d77a71fd..308112078 100644 --- a/src/codegen/cli/utils/function_finder.py +++ b/src/codegen/cli/utils/function_finder.py @@ -80,7 +80,7 @@ def validate(self) -> None: class CodegenFunctionVisitor(ast.NodeVisitor): - def __init__(self): + def __init__(self) -> None: self.functions: list[DecoratedFunction] = [] def get_function_body(self, node: ast.FunctionDef) -> str: @@ -174,7 +174,7 @@ def get_function_parameters(self, node: ast.FunctionDef) -> list[tuple[str, str return parameters - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node) -> None: for decorator in node.decorator_list: if ( isinstance(decorator, ast.Call) @@ -219,7 +219,7 @@ def _get_decorator_attrs(self, node): node = node.value return attrs - def visit_Module(self, node): + def visit_Module(self, node) -> None: # Store the full source code for later use self.source = self.file_content self.generic_visit(node) diff --git a/src/codegen/cli/workspace/docs_workspace.py b/src/codegen/cli/workspace/docs_workspace.py index f12437857..9af379a4d 100644 --- a/src/codegen/cli/workspace/docs_workspace.py +++ b/src/codegen/cli/workspace/docs_workspace.py @@ -4,7 +4,7 @@ from rich.status import Status -def populate_api_docs(dest: Path, api_docs: dict[str, str], status: Status): +def populate_api_docs(dest: Path, api_docs: dict[str, str], status: Status) -> None: """Writes all API docs to the docs folder""" status.update("Populating API documentation...") # Remove existing docs diff --git a/src/codegen/cli/workspace/examples_workspace.py b/src/codegen/cli/workspace/examples_workspace.py index 54c6bfbd3..d9991e132 100644 --- a/src/codegen/cli/workspace/examples_workspace.py +++ b/src/codegen/cli/workspace/examples_workspace.py @@ -8,7 +8,7 @@ from codegen.cli.codemod.convert import convert_to_cli -def populate_examples(session: CodegenSession, dest: Path, examples: list[SerializedExample], status: Status): +def populate_examples(session: CodegenSession, dest: Path, examples: list[SerializedExample], status: Status) -> None: """Populate the examples folder with examples for the current repository.""" status.update("Populating example codemods...") # Remove existing examples diff --git a/src/codegen/cli/workspace/initialize_workspace.py b/src/codegen/cli/workspace/initialize_workspace.py index e57d7ee50..89eee6d6c 100644 --- a/src/codegen/cli/workspace/initialize_workspace.py +++ b/src/codegen/cli/workspace/initialize_workspace.py @@ -117,14 +117,14 @@ def initialize_codegen( return CODEGEN_FOLDER, DOCS_FOLDER, EXAMPLES_FOLDER -def add_to_gitignore_if_not_present(gitignore: Path, line: str): +def add_to_gitignore_if_not_present(gitignore: Path, line: str) -> None: if not gitignore.exists(): gitignore.write_text(line) elif line not in gitignore.read_text(): gitignore.write_text(gitignore.read_text() + "\n" + line) -def modify_gitignore(codegen_folder: Path): +def modify_gitignore(codegen_folder: Path) -> None: """Update .gitignore to track only specific Codegen files.""" gitignore_path = codegen_folder / ".gitignore" diff --git a/src/codegen/cli/workspace/venv_manager.py b/src/codegen/cli/workspace/venv_manager.py index 6dc1bd86e..cc9b8d7ab 100644 --- a/src/codegen/cli/workspace/venv_manager.py +++ b/src/codegen/cli/workspace/venv_manager.py @@ -8,7 +8,7 @@ class VenvManager: """Manages the virtual environment for codegen.""" - def __init__(self): + def __init__(self) -> None: self.codegen_dir = Path.cwd() / CODEGEN_DIR self.venv_dir = self.codegen_dir / ".venv" @@ -17,7 +17,7 @@ def is_initialized(self) -> bool: python_path = self.venv_dir / "bin" / "python" return self.venv_dir.exists() and python_path.exists() - def create_venv(self, python_version: str = "3.13"): + def create_venv(self, python_version: str = "3.13") -> None: """Create a virtual environment using uv.""" self.codegen_dir.mkdir(parents=True, exist_ok=True) subprocess.run( @@ -25,7 +25,7 @@ def create_venv(self, python_version: str = "3.13"): check=True, ) - def install_packages(self, *packages: str): + def install_packages(self, *packages: str) -> None: """Install packages into the virtual environment using uv pip.""" subprocess.run( ["uv", "pip", "install", *packages], @@ -33,7 +33,7 @@ def install_packages(self, *packages: str): env={**os.environ, "VIRTUAL_ENV": str(self.venv_dir)}, ) - def run_python(self, script: str, *args: str): + def run_python(self, script: str, *args: str) -> None: """Run a Python script in the virtual environment.""" python_path = self.venv_dir / "bin" / "python" subprocess.run([str(python_path), "-c", script, *args], check=True) @@ -46,7 +46,7 @@ def is_active(self) -> bool: """Check if a virtual environment is active.""" return "VIRTUAL_ENV" in os.environ - def ensure_jupyter(self): + def ensure_jupyter(self) -> None: """Ensure Jupyter Lab is installed in the virtual environment.""" try: subprocess.run( diff --git a/src/codegen/git/clients/git_repo_client.py b/src/codegen/git/clients/git_repo_client.py index aab02affe..5ac70bf88 100644 --- a/src/codegen/git/clients/git_repo_client.py +++ b/src/codegen/git/clients/git_repo_client.py @@ -417,7 +417,7 @@ def get_workflow_safe(self, file_name: str) -> Workflow | None: logger.warning(f"Error getting workflow by file name: {file_name}\n\t{e}") return None - def create_workflow_dispatch(self, workflow: Workflow, ref: Branch | Tag | Commit | str, inputs: Opt[dict] = NotSet): + def create_workflow_dispatch(self, workflow: Workflow, ref: Branch | Tag | Commit | str, inputs: Opt[dict] = NotSet) -> None: writeable_workflow = self._write_client.get_workflow(workflow.id) writeable_workflow.create_dispatch(ref=ref, inputs=inputs) diff --git a/src/codegen/gscli/backend/typestub_utils.py b/src/codegen/gscli/backend/typestub_utils.py index 49ff8f909..659a24f73 100644 --- a/src/codegen/gscli/backend/typestub_utils.py +++ b/src/codegen/gscli/backend/typestub_utils.py @@ -11,7 +11,7 @@ class MethodRemover(ast.NodeTransformer): - def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]): + def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]) -> None: self.conditions = conditions def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: @@ -43,7 +43,7 @@ def should_remove(self, node: ast.FunctionDef | ast.AnnAssign) -> bool: class FieldRemover(ast.NodeTransformer): - def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]): + def __init__(self, conditions: list[Callable[[ast.FunctionDef], bool]]) -> None: self.conditions = conditions def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: diff --git a/src/codegen/gscli/generate/runner_imports.py b/src/codegen/gscli/generate/runner_imports.py index f06c2301a..b30189af5 100644 --- a/src/codegen/gscli/generate/runner_imports.py +++ b/src/codegen/gscli/generate/runner_imports.py @@ -51,7 +51,7 @@ def get_generated_imports(): ) -def fix_ruff_imports(objects: list[DocumentedObject]): +def fix_ruff_imports(objects: list[DocumentedObject]) -> None: root, _ = split_git_path(str(Path(__file__))) to_add = [] for obj in objects: diff --git a/src/codegen/gscli/generate/utils.py b/src/codegen/gscli/generate/utils.py index d579f9288..6336fdd63 100644 --- a/src/codegen/gscli/generate/utils.py +++ b/src/codegen/gscli/generate/utils.py @@ -12,7 +12,7 @@ class LanguageType(StrEnum): BOTH = "BOTH" -def generate_builtins_file(path_to_builtins: str, language_type: LanguageType): +def generate_builtins_file(path_to_builtins: str, language_type: LanguageType) -> None: """Generates and writes the builtins file""" documented_imports = get_documented_objects() all_objects = chain(documented_imports["apidoc"], documented_imports["py_apidoc"], documented_imports["ts_apidoc"]) diff --git a/src/codegen/runner/sandbox/executor.py b/src/codegen/runner/sandbox/executor.py index b47d0e4ad..e0f913556 100644 --- a/src/codegen/runner/sandbox/executor.py +++ b/src/codegen/runner/sandbox/executor.py @@ -28,7 +28,7 @@ class SandboxExecutor: codebase: CodebaseType remote_repo: SandboxRepo - def __init__(self, codebase: CodebaseType): + def __init__(self, codebase: CodebaseType) -> None: self.codebase = codebase self.remote_repo = SandboxRepo(self.codebase) diff --git a/src/codegen/runner/sandbox/middlewares.py b/src/codegen/runner/sandbox/middlewares.py index 96e0d6b85..594087f13 100644 --- a/src/codegen/runner/sandbox/middlewares.py +++ b/src/codegen/runner/sandbox/middlewares.py @@ -65,7 +65,7 @@ async def process_request(self, request: TRequest, call_next: RequestResponseEnd res.background = background_tasks return res - async def cleanup_after_codemod(self, is_exception: bool = False): + async def cleanup_after_codemod(self, is_exception: bool = False) -> None: if is_exception: # TODO: instead of committing transactions, we should just rollback logger.info("Committing pending transactions due to exception") diff --git a/src/codegen/runner/sandbox/repo.py b/src/codegen/runner/sandbox/repo.py index b3d06c37f..8afa0d7a9 100644 --- a/src/codegen/runner/sandbox/repo.py +++ b/src/codegen/runner/sandbox/repo.py @@ -29,7 +29,7 @@ def set_up_base_branch(self, base_branch: str | None) -> None: # checkout the base branch (and possibly sync graph) self.codebase.checkout(branch=base_branch) - def set_up_head_branch(self, head_branch: str, force_push_head_branch: bool): + def set_up_head_branch(self, head_branch: str, force_push_head_branch: bool) -> None: """Set-up head branch by pushing latest highside branch to lowside and fetching the branch (so that it can be checked out later).""" # If head branch is not specified, do nothing if head_branch is None: diff --git a/src/codegen/sdk/ai/converters.py b/src/codegen/sdk/ai/converters.py index 4165d8bd4..bd5be97b4 100644 --- a/src/codegen/sdk/ai/converters.py +++ b/src/codegen/sdk/ai/converters.py @@ -1,4 +1,6 @@ # TODO: these should move to claude +from __future__ import annotations + import json import logging @@ -39,7 +41,7 @@ def convert_messages(messages: list) -> list: elif isinstance(message, dict): new_message = convert_dict_message(message) else: - logger.warn(f"Message type not supported: {type(message)}") + logger.warning(f"Message type not supported: {type(message)}") continue new_messages.append(new_message) return new_messages @@ -66,15 +68,14 @@ def convert_openai_message(message): def convert_dict_message(message): """Converts a dictionary message to an Anthropic message.""" role = message.get("role") - if role == "user" or role == "assistant": + if role in ("user", "assistant"): if "tool_calls" in message: return create_anthropic_message_with_tool_calls(role, message["content"], message["tool_calls"]) return {"role": role, "content": message["content"]} - elif role == "tool": + if role == "tool": return {"role": "user", "content": [{"type": "tool_result", "tool_use_id": message["tool_call_id"], "content": message["content"]}]} - else: - logger.warn(f"Unsupported role: {role}") - return None + logger.warning(f"Unsupported role: {role}") + return None def create_anthropic_message_with_tool_calls(role, content, tool_calls): @@ -89,12 +90,12 @@ def create_anthropic_message_with_tool_calls(role, content, tool_calls): "id": tool_call.id if hasattr(tool_call, "id") else tool_call["id"], "name": tool_call.function.name if hasattr(tool_call, "function") else tool_call["function"]["name"], "input": json.loads(tool_call.function.arguments if hasattr(tool_call, "function") else tool_call["function"]["arguments"]), - } + }, ) return {"role": role, "content": message_content} -def merge_user_messages(existing_message, new_message): +def merge_user_messages(existing_message, new_message) -> None: """Merges two user messages into a single message.""" if isinstance(existing_message["content"], list): existing_message["content"].extend(new_message["content"] if isinstance(new_message["content"], list) else [{"type": "text", "text": new_message["content"]}]) diff --git a/src/codegen/sdk/ai/helpers.py b/src/codegen/sdk/ai/helpers.py index a5fe6b3bd..dfc6872ec 100644 --- a/src/codegen/sdk/ai/helpers.py +++ b/src/codegen/sdk/ai/helpers.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import json import logging from abc import ABC, abstractmethod +from typing import NoReturn import anthropic import anthropic.types as anthropic_types @@ -33,10 +36,10 @@ def count_tokens(s: str, model_name: str = "gpt-4-32k") -> int: - """Uses tiktoken""" + """Uses tiktoken.""" if s is None: return 0 - enc = ENCODERS.get(model_name, None) + enc = ENCODERS.get(model_name) if not enc: ENCODERS[model_name] = tiktoken.encoding_for_model(model_name) enc = ENCODERS[model_name] @@ -64,7 +67,7 @@ def __init__(self) -> None: raise NotImplementedError(msg) @abstractmethod - def embeddings_with_backoff(self, **kwargs): + def embeddings_with_backoff(self, **kwargs) -> NoReturn: msg = "This is an abstract class" raise NotImplementedError(msg) @@ -79,27 +82,29 @@ def get_embedding(self, content_str: str) -> list[float]: raise NotImplementedError(msg) @abstractmethod - def llm_query_with_retry(self, **kwargs): + def llm_query_with_retry(self, **kwargs) -> NoReturn: msg = "This is an abstract class" raise NotImplementedError(msg) @abstractmethod - def llm_query_no_retry(self, messages: list = [], model: str = "gpt-4-32k", max_tokens: int | None = None): + def llm_query_no_retry(self, messages: list | None = None, model: str = "gpt-4-32k", max_tokens: int | None = None) -> NoReturn: + if messages is None: + messages = [] msg = "This is an abstract class" raise NotImplementedError(msg) @abstractmethod - def llm_query_functions_with_retry(self, model: str, messages: list, functions: list[dict], max_tokens: int | None = None): + def llm_query_functions_with_retry(self, model: str, messages: list, functions: list[dict], max_tokens: int | None = None) -> NoReturn: msg = "This is an abstract class" raise NotImplementedError(msg) @abstractmethod - def llm_query_functions(self, model: str, messages: list, functions: list[dict], max_tokens: int | None = None): + def llm_query_functions(self, model: str, messages: list, functions: list[dict], max_tokens: int | None = None) -> NoReturn: msg = "This is an abstract class" raise NotImplementedError(msg) @abstractmethod - def llm_response_to_json(response) -> str: + def llm_response_to_json(self) -> str: msg = "This is an abstract class" raise NotImplementedError(msg) @@ -151,7 +156,9 @@ def completions_with_backoff(self, **kwargs): def llm_query_with_retry(self, **kwargs): return self.llm_query_no_retry(**kwargs) - def llm_query_no_retry(self, messages: list = [], model: str = "gpt-4-32k", max_tokens: int | None = None, **kwargs): + def llm_query_no_retry(self, messages: list | None = None, model: str = "gpt-4-32k", max_tokens: int | None = None, **kwargs): + if messages is None: + messages = [] return self.client.chat.completions.create( messages=messages, model=model, @@ -227,7 +234,10 @@ def _convert_openai_functions_to_claude(self, functions: list[dict]) -> list[ant return new_functions def _convert_claude_response_to_openai( - self, response: anthropic_types.Message | anthropic_tool_types.ToolsBetaMessage, parse_function_calls: bool = False, parse_result_block: bool = False + self, + response: anthropic_types.Message | anthropic_tool_types.ToolsBetaMessage, + parse_function_calls: bool = False, + parse_result_block: bool = False, ) -> openai_types.chat_completion.ChatCompletion: choices = [] if len(response.content) != 0: @@ -244,7 +254,7 @@ def _convert_claude_response_to_openai( index=0, finish_reason="stop" if response.stop_reason in ("end_turn", "stop_sequence") else "length", message=openai_types.chat_completion_message.ChatCompletionMessage(content=resp.text, role="assistant"), - ) + ), ) elif isinstance(resp, anthropic_tool_types.ToolUseBlock): # If the previous choice is a chat message, then we can add the tool call to it @@ -269,10 +279,10 @@ def _convert_claude_response_to_openai( arguments=json.dumps(resp.input), ), type="function", - ) + ), ], ), - ) + ), ) return openai_types.chat_completion.ChatCompletion( id=response.id, @@ -284,7 +294,7 @@ def _convert_claude_response_to_openai( ) @backoff.on_exception(backoff.expo, anthropic.RateLimitError) - def embeddings_with_backoff(self, **kwargs): + def embeddings_with_backoff(self, **kwargs) -> NoReturn: msg = "Embeddings are not supported for AnthropicHelper" raise NotImplementedError(msg) # response = self.client.embeddings.create(**kwargs) @@ -304,7 +314,7 @@ def get_embedding(self, content_str: str) -> list[float]: # return embeddings[0] @backoff.on_exception(backoff.expo, anthropic.RateLimitError) - def completions_with_backoff(self, **kwargs): + def completions_with_backoff(self, **kwargs) -> NoReturn: msg = "Claude's completion api is deprecated. Please use messages_with_backoff instead." raise Exception(msg) @@ -316,8 +326,10 @@ def messages_with_backoff(self, **kwargs): def llm_query_with_retry(self, **kwargs): return self.llm_query_no_retry(**kwargs) - def llm_query_no_retry(self, messages: list = [], model: str = "claude-2.0", max_tokens: int | None = None, system_prompt: str | anthropic.NotGiven | None = None, **kwargs): - system_prompt = anthropic.NotGiven() if not system_prompt else system_prompt + def llm_query_no_retry(self, messages: list | None = None, model: str = "claude-2.0", max_tokens: int | None = None, system_prompt: str | anthropic.NotGiven | None = None, **kwargs): + if messages is None: + messages = [] + system_prompt = system_prompt if system_prompt else anthropic.NotGiven() if self.openai_anthropic_translation and model in CLAUDE_OPENAI_MODEL_MAP: model = CLAUDE_OPENAI_MODEL_MAP[model] if self.openai_anthropic_translation: @@ -329,15 +341,14 @@ def llm_query_no_retry(self, messages: list = [], model: str = "claude-2.0", max response = self.client.beta.tools.messages.create(max_tokens=max_tokens, system=claude_system_prompt, messages=messages, model=model, **kwargs) if self.openai_anthropic_translation: return self._convert_claude_response_to_openai(response) - else: - return response + return response @retry(wait=wait_random_exponential(min=70, max=600), stop=stop_after_attempt(10)) def llm_query_functions_with_retry(self, **kwargs): return self.llm_query_functions(**kwargs) def llm_query_functions(self, model: str, messages: list, functions: list, max_tokens: int | None = None, system_prompt: str | anthropic.NotGiven | None = None, **kwargs): - system_prompt = anthropic.NotGiven() if not system_prompt else system_prompt + system_prompt = system_prompt if system_prompt else anthropic.NotGiven() if self.openai_anthropic_translation and model in CLAUDE_OPENAI_MODEL_MAP: model = CLAUDE_OPENAI_MODEL_MAP[model] if functions is not None: @@ -359,27 +370,23 @@ def llm_query_functions(self, model: str, messages: list, functions: list, max_t ) if self.openai_anthropic_translation: return self._convert_claude_response_to_openai(response, parse_function_calls=True, parse_result_block=True) - else: - return response - else: - response = self.llm_query_no_retry( - model=model, - messages=messages, - max_tokens=max_tokens or anthropic.NotGiven(), - system_prompt=system_prompt, - **kwargs, - ) - return response + return response + return self.llm_query_no_retry( + model=model, + messages=messages, + max_tokens=max_tokens or anthropic.NotGiven(), + system_prompt=system_prompt, + **kwargs, + ) @staticmethod def llm_response_to_json(response: openai_types.chat_completion.ChatCompletion | anthropic_types.Message) -> str: if isinstance(response, openai_types.chat_completion.ChatCompletion): return OpenAIHelper.llm_response_to_json(response) - else: - js = json.loads(response.model_dump_json()) - if len(js["content"]) == 0: - return "" - return js["content"][0]["text"] + js = json.loads(response.model_dump_json()) + if len(js["content"]) == 0: + return "" + return js["content"][0]["text"] class MultiProviderAIHelper(AbstractAIHelper): @@ -411,62 +418,58 @@ def embeddings_with_backoff(self, **kwargs): # Prioritize OpenAI First if self.use_openai: return self.openai_helper.embeddings_with_backoff(**kwargs) - elif self.use_claude: + if self.use_claude: return self.anthropic_helper.embeddings_with_backoff(**kwargs) - else: - msg = "MultiProviderAIHelper: No AI helper is enabled" - raise Exception(msg) + msg = "MultiProviderAIHelper: No AI helper is enabled" + raise Exception(msg) def get_embeddings(self, content_strs: list[str]) -> list[list[float]]: # Prioritize OpenAI First if self.use_openai: return self.openai_helper.get_embeddings(content_strs) - elif self.use_claude: + if self.use_claude: return self.anthropic_helper.get_embeddings(content_strs) - else: - msg = "MultiProviderAIHelper: No AI helper is enabled" - raise Exception(msg) + msg = "MultiProviderAIHelper: No AI helper is enabled" + raise Exception(msg) def get_embedding(self, content_str: str) -> list[float]: # Prioritize OpenAI First if self.use_openai: return self.openai_helper.get_embedding(content_str) - elif self.use_claude: + if self.use_claude: return self.anthropic_helper.get_embedding(content_str) - else: - msg = "MultiProviderAIHelper: No AI helper is enabled" - raise Exception(msg) + msg = "MultiProviderAIHelper: No AI helper is enabled" + raise Exception(msg) @backoff.on_exception(backoff.expo, anthropic.RateLimitError) def completions_with_backoff(self, **kwargs): # This is OpenAI specific if self.use_openai: return self.openai_helper.completions_with_backoff(**kwargs) - else: - msg = "MultiProviderAIHelper: OpenAI Helper is not enabled" - raise Exception(msg) + msg = "MultiProviderAIHelper: OpenAI Helper is not enabled" + raise Exception(msg) @backoff.on_exception(backoff.expo, anthropic.RateLimitError) def messages_with_backoff(self, **kwargs): # This is Anthropic specific if self.use_claude: return self.anthropic_helper.messages_with_backoff(**kwargs) - else: - msg = "MultiProviderAIHelper: Anthropic Helper is not enabled" - raise Exception(msg) + msg = "MultiProviderAIHelper: Anthropic Helper is not enabled" + raise Exception(msg) @retry(wait=wait_random_exponential(min=70, max=600), stop=stop_after_attempt(10)) def llm_query_with_retry(self, **kwargs): return self.llm_query_no_retry(**kwargs) - def llm_query_no_retry(self, messages: list = [], model: str = "gpt-4-32k", max_tokens: int | None = None, **kwargs): + def llm_query_no_retry(self, messages: list | None = None, model: str = "gpt-4-32k", max_tokens: int | None = None, **kwargs): + if messages is None: + messages = [] if self.use_openai and model.startswith("gpt"): return self.openai_helper.llm_query_no_retry(messages=messages, model=model, max_tokens=max_tokens, **kwargs) - elif self.use_claude and model.startswith("claude"): + if self.use_claude and model.startswith("claude"): return self.anthropic_helper.llm_query_no_retry(messages=messages, model=model, max_tokens=max_tokens, **kwargs) - else: - msg = f"MultiProviderAIHelper: Unknown Model {model}" - raise Exception(msg) + msg = f"MultiProviderAIHelper: Unknown Model {model}" + raise Exception(msg) @retry(wait=wait_random_exponential(min=70, max=600), stop=stop_after_attempt(10)) def llm_query_functions_with_retry(self, **kwargs): @@ -475,11 +478,10 @@ def llm_query_functions_with_retry(self, **kwargs): def llm_query_functions(self, model: str, messages: list, functions: list[dict], max_tokens: int | None = None, **kwargs): if self.use_openai and model.startswith("gpt"): return self.openai_helper.llm_query_functions(model, messages, functions, max_tokens, **kwargs) - elif self.use_claude and model.startswith("claude"): + if self.use_claude and model.startswith("claude"): return self.anthropic_helper.llm_query_functions(model, messages, functions, max_tokens, **kwargs) - else: - msg = f"MultiProviderAIHelper: Unknown Model {model}" - raise Exception(msg) + msg = f"MultiProviderAIHelper: Unknown Model {model}" + raise Exception(msg) @staticmethod def llm_response_to_json(response) -> str: diff --git a/src/codegen/sdk/code_generation/current_code_codebase.py b/src/codegen/sdk/code_generation/current_code_codebase.py index 1a483d1c0..0b59ddb4a 100644 --- a/src/codegen/sdk/code_generation/current_code_codebase.py +++ b/src/codegen/sdk/code_generation/current_code_codebase.py @@ -43,7 +43,7 @@ def get_current_code_codebase(config: CodebaseConfig = DefaultConfig, subdirecto return codebase -def import_all_codegen_sdk_modules(): +def import_all_codegen_sdk_modules() -> None: # for file in codegen.sdk: CODEGEN_SDK_DIR = Path(get_graphsitter_repo_path()) diff --git a/src/codegen/sdk/code_generation/doc_utils/utils.py b/src/codegen/sdk/code_generation/doc_utils/utils.py index d0b939de2..18c2ac140 100644 --- a/src/codegen/sdk/code_generation/doc_utils/utils.py +++ b/src/codegen/sdk/code_generation/doc_utils/utils.py @@ -115,7 +115,7 @@ def get_langauge(symbol: Class | Function | PyAttribute) -> str: return "ALL" -def get_type(method: Function): +def get_type(method: Function) -> str: """Return the type of method. Args: diff --git a/src/codegen/sdk/codebase/codebase_graph.py b/src/codegen/sdk/codebase/codebase_graph.py index 7014a2a52..102278a5e 100644 --- a/src/codegen/sdk/codebase/codebase_graph.py +++ b/src/codegen/sdk/codebase/codebase_graph.py @@ -167,7 +167,7 @@ def __init__( self.pending_files = set() self.flags = Flags() - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ @stopwatch_with_sentry(name="build_graph") @@ -487,7 +487,7 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr finally: self._computing = False - def _compute_dependencies(self, to_update: list[Importable], incremental: bool): + def _compute_dependencies(self, to_update: list[Importable], incremental: bool) -> None: seen = set() while to_update: step = to_update.copy() @@ -621,7 +621,7 @@ def out_edges(self, n: NodeId) -> WeightedEdgeList[Edge]: def remove_node(self, n: NodeId): return self._graph.remove_node(n) - def remove_edge(self, u: NodeId, v: NodeId, *, edge_type: EdgeType | None = None): + def remove_edge(self, u: NodeId, v: NodeId, *, edge_type: EdgeType | None = None) -> None: for edge in self._graph.edge_indices_from_endpoints(u, v): if edge_type is not None: if self._graph.get_edge_data_by_index(edge).type != edge_type: diff --git a/src/codegen/sdk/codebase/config_parser.py b/src/codegen/sdk/codebase/config_parser.py index 4f8ac6cc7..50a9f9499 100644 --- a/src/codegen/sdk/codebase/config_parser.py +++ b/src/codegen/sdk/codebase/config_parser.py @@ -8,7 +8,7 @@ class ConfigParser(ABC): - def __init__(self): + def __init__(self) -> None: pass @abstractmethod diff --git a/src/codegen/sdk/codebase/factory/codebase_factory.py b/src/codegen/sdk/codebase/factory/codebase_factory.py index 84e7379f1..b3ff91df9 100644 --- a/src/codegen/sdk/codebase/factory/codebase_factory.py +++ b/src/codegen/sdk/codebase/factory/codebase_factory.py @@ -1,3 +1,5 @@ +from typing import Optional + from codegen.git.repo_operator.local_repo_operator import LocalRepoOperator from codegen.git.schemas.repo_config import BaseRepoConfig from codegen.sdk.codebase.config import CodebaseConfig, ProjectConfig @@ -16,12 +18,14 @@ class CodebaseFactory: @staticmethod def get_codebase_from_files( repo_path: str = "/tmp/codegen_run_on_str", - files: dict[str, str] = {}, + files: Optional[dict[str, str]] = None, bot_commit: bool = True, repo_config: BaseRepoConfig = BaseRepoConfig(), programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, config: CodebaseConfig = CodebaseConfig(), ) -> CodebaseType: + if files is None: + files = {} op = LocalRepoOperator.create_from_files(repo_path=repo_path, files=files, bot_commit=bot_commit, repo_config=repo_config) projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)] return Codebase(projects=projects, config=config) diff --git a/src/codegen/sdk/codebase/factory/get_session.py b/src/codegen/sdk/codebase/factory/get_session.py index ec4a31cce..136ed3231 100644 --- a/src/codegen/sdk/codebase/factory/get_session.py +++ b/src/codegen/sdk/codebase/factory/get_session.py @@ -1,19 +1,23 @@ +from __future__ import annotations + import os -import sys -from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager -from typing import Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, overload from codegen.git.repo_operator.local_repo_operator import LocalRepoOperator from codegen.git.schemas.repo_config import BaseRepoConfig from codegen.sdk.codebase.codebase_graph import CodebaseGraph from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags, ProjectConfig, SessionOptions, TestFlags from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory -from codegen.sdk.core.codebase import Codebase, PyCodebaseType, TSCodebaseType from codegen.sdk.enums import ProgrammingLanguage from codegen.sdk.secrets import Secrets from codegen.sdk.tree_sitter_parser import print_errors +if TYPE_CHECKING: + from collections.abc import Generator + + from codegen.sdk.core.codebase import Codebase, PyCodebaseType, TSCodebaseType + @overload def get_codebase_session( @@ -51,7 +55,7 @@ def get_codebase_session( def get_codebase_session( tmpdir: Any, programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, - files: dict[str, str] = {}, + files: dict[str, str] | None = None, commit: bool = True, sync_graph: bool = True, verify_input: bool = True, @@ -61,7 +65,9 @@ def get_codebase_session( session_options: SessionOptions = SessionOptions(), secrets: Secrets = Secrets(), ) -> Generator[Codebase, None, None]: - """Gives you a Codebase operating on the files you provided as a dict""" + """Gives you a Codebase operating on the files you provided as a dict.""" + if files is None: + files = {} config = CodebaseConfig(feature_flags=feature_flags, secrets=secrets) codebase = CodebaseFactory.get_codebase_from_files(repo_path=str(tmpdir), files=files, config=config, programming_language=programming_language, repo_config=repo_config) with codebase.session( @@ -80,7 +86,7 @@ def get_codebase_session( for file in codebase.files: if os.path.exists(file.filepath): if file.ts_node.has_error and len(file.content.splitlines()) < 10: - print(file.content, file=sys.stderr) + pass print_errors(file.filepath, file.content) assert not file.ts_node.has_error, "Invalid syntax in file after commiting" @@ -89,11 +95,13 @@ def get_codebase_session( def get_codebase_graph_session( tmpdir: str, programming_language: ProgrammingLanguage = ProgrammingLanguage.PYTHON, - files: dict[str, str] = {}, + files: dict[str, str] | None = None, sync_graph: bool = True, session_options: SessionOptions = SessionOptions(), ) -> Generator[CodebaseGraph, None, None]: - """Gives you a Codebase2 operating on the files you provided as a dict""" + """Gives you a Codebase2 operating on the files you provided as a dict.""" + if files is None: + files = {} op = LocalRepoOperator.create_from_files(repo_path=tmpdir, files=files) config = CodebaseConfig(feature_flags=TestFlags) projects = [ProjectConfig(repo_operator=op, programming_language=programming_language)] diff --git a/src/codegen/sdk/codebase/flagging/code_flag.py b/src/codegen/sdk/codebase/flagging/code_flag.py index cb10a0057..b1d43b5c1 100644 --- a/src/codegen/sdk/codebase/flagging/code_flag.py +++ b/src/codegen/sdk/codebase/flagging/code_flag.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import TYPE_CHECKING @@ -27,9 +29,7 @@ def __eq__(self, other): return False if self.message != other.message: return False - if self.message_type != other.message_type: - return False - return True + return not self.message_type != other.message_type - def __repr__(self): + def __repr__(self) -> str: return f"" diff --git a/src/codegen/sdk/codebase/flagging/enums.py b/src/codegen/sdk/codebase/flagging/enums.py index 949f09ce5..8c405c17b 100644 --- a/src/codegen/sdk/codebase/flagging/enums.py +++ b/src/codegen/sdk/codebase/flagging/enums.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import IntFlag, auto from typing import TypedDict @@ -8,7 +10,7 @@ @apidoc class MessageType(IntFlag): - """Destination of the message + """Destination of the message. Attributes: CODEGEN: Rendered in the diff preview diff --git a/src/codegen/sdk/codebase/flagging/flags.py b/src/codegen/sdk/codebase/flagging/flags.py index 13288e40c..f943d065d 100644 --- a/src/codegen/sdk/codebase/flagging/flags.py +++ b/src/codegen/sdk/codebase/flagging/flags.py @@ -1,11 +1,16 @@ +from __future__ import annotations + from dataclasses import dataclass, field +from typing import TYPE_CHECKING from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.enums import MessageType -from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.core.interfaces.editable import Editable from codegen.shared.decorators.docs import noapidoc +if TYPE_CHECKING: + from codegen.sdk.codebase.flagging.group import Group + @dataclass class Flags: @@ -56,10 +61,9 @@ def should_fix(self, flag: CodeFlag) -> bool: """ if self._find_mode: return False - elif self._active_group is None: + if self._active_group is None: return True - else: - return flag.hash in self._active_group_hashes + return flag.hash in self._active_group_hashes @noapidoc def set_find_mode(self, find_mode: bool) -> None: @@ -71,4 +75,4 @@ def set_active_group(self, group: Group) -> None: # TODO - flesh this out more with Group datatype and GroupBy self._active_group = group.flags self._find_mode = False - self._active_group_hashes = set(flag.hash for flag in group.flags) + self._active_group_hashes = {flag.hash for flag in group.flags} diff --git a/src/codegen/sdk/codebase/flagging/group.py b/src/codegen/sdk/codebase/flagging/group.py index 58f6f9e95..16f49c9bd 100644 --- a/src/codegen/sdk/codebase/flagging/group.py +++ b/src/codegen/sdk/codebase/flagging/group.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from dataclasses import dataclass +from typing import TYPE_CHECKING from dataclasses_json import dataclass_json -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +if TYPE_CHECKING: + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + from codegen.sdk.codebase.flagging.groupers.enums import GroupBy DEFAULT_GROUP_ID = 0 diff --git a/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py index 64ef65d6e..fcdc247d4 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py @@ -1,9 +1,15 @@ -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from __future__ import annotations + +from typing import TYPE_CHECKING + from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + class AllGrouper(BaseGrouper): """Group all flags into one group.""" diff --git a/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py index ec8b9904d..a20f2975e 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + logger = logging.getLogger(__name__) diff --git a/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py index f2b93017a..69a89cf25 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py @@ -1,13 +1,18 @@ -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag -from codegen.sdk.codebase.flagging.group import Group -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + from codegen.sdk.codebase.flagging.group import Group + from codegen.sdk.codebase.flagging.groupers.enums import GroupBy class BaseGrouper: """Base class of all groupers. Children of this class should include in their doc string: - - a short desc of what the segment format is. ex: for FileGrouper the segment is a filename + - a short desc of what the segment format is. ex: for FileGrouper the segment is a filename. """ type: GroupBy @@ -24,6 +29,6 @@ def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | @staticmethod def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: - """TODO: handle the case when 0 flags are passed in""" + """TODO: handle the case when 0 flags are passed in.""" msg = "Must implement create_single_group in BaseGrouper" raise NotImplementedError(msg) diff --git a/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py index d7ff0a51a..6666765c4 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py @@ -1,9 +1,15 @@ -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from __future__ import annotations + +from typing import TYPE_CHECKING + from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + DEFAULT_CHUNK_SIZE = 5 diff --git a/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py index c4a198c5c..0392a8910 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.enums import GroupBy from codegen.shared.string.csv_utils import comma_separated_to_list, list_to_comma_separated +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + logger = logging.getLogger(__name__) DEFAULT_CHUNK_SIZE = 5 diff --git a/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py index 6cc537afa..34787dbf1 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator -from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + from codegen.sdk.codebase.flagging.code_flag import CodeFlag + logger = logging.getLogger(__name__) @@ -19,7 +24,7 @@ class FileGrouper(BaseGrouper): @staticmethod def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: groups = [] - filenames = sorted(list({f.filepath for f in flags})) + filenames = sorted({f.filepath for f in flags}) for idx, filename in enumerate(filenames): filename_flags = [flag for flag in flags if flag.filepath == filename] groups.append(Group(id=idx, group_by=GroupBy.FILE, segment=filename, flags=filename_flags)) diff --git a/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py index 953c7f355..59a4de2bc 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py +++ b/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py @@ -1,9 +1,15 @@ -from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from __future__ import annotations + +from typing import TYPE_CHECKING + from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +if TYPE_CHECKING: + from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator + class InstanceGrouper(BaseGrouper): """Group flags by flags. haha @@ -22,6 +28,6 @@ def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: Remo try: flag = CodeFlag.from_json(segment) return Group(group_by=GroupBy.INSTANCE, segment=segment, flags=[flag]) - except Exception as e: + except Exception: msg = f"Unable to deserialize segment ({segment}) into CodeFlag. Unable to create group." raise ValueError(msg) diff --git a/src/codegen/sdk/codebase/flagging/groupers/utils.py b/src/codegen/sdk/codebase/flagging/groupers/utils.py index 38d43cfa2..1f4f2f779 100644 --- a/src/codegen/sdk/codebase/flagging/groupers/utils.py +++ b/src/codegen/sdk/codebase/flagging/groupers/utils.py @@ -1,7 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper -from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper from codegen.sdk.codebase.flagging.groupers.constants import ALL_GROUPERS -from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + +if TYPE_CHECKING: + from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper + from codegen.sdk.codebase.flagging.groupers.enums import GroupBy def get_grouper_by_group_by(group_by: GroupBy | None) -> type[BaseGrouper]: diff --git a/src/codegen/sdk/codebase/range_index.py b/src/codegen/sdk/codebase/range_index.py index ea8cdd4b9..fbb47ac31 100644 --- a/src/codegen/sdk/codebase/range_index.py +++ b/src/codegen/sdk/codebase/range_index.py @@ -12,7 +12,7 @@ class RangeIndex: _ranges: defaultdict[Range, list[Editable]] _canonical_range: defaultdict[Range, dict[int, Editable]] - def __init__(self): + def __init__(self) -> None: self._ranges = defaultdict(list) self._canonical_range = defaultdict(dict) @@ -29,7 +29,7 @@ def get_canonical_for_range(self, range: Range, kind_id: int) -> Editable | None if mapping := self._canonical_range.get(range, None): return mapping.get(kind_id, None) - def clear(self): + def clear(self) -> None: self._ranges.clear() self._canonical_range.clear() self.__dict__.pop("children", None) diff --git a/src/codegen/sdk/codebase/transaction_manager.py b/src/codegen/sdk/codebase/transaction_manager.py index bba56b1bc..15728bd28 100644 --- a/src/codegen/sdk/codebase/transaction_manager.py +++ b/src/codegen/sdk/codebase/transaction_manager.py @@ -1,6 +1,6 @@ +from __future__ import annotations + import time -from collections.abc import Callable -from pathlib import Path from typing import TYPE_CHECKING from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite @@ -16,6 +16,9 @@ from codegen.shared.exceptions.control_flow import MaxPreviewTimeExceeded, MaxTransactionsExceeded if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + from codegen.sdk.core.file import File import logging @@ -43,11 +46,11 @@ class TransactionManager: stopwatch_max_seconds: int | None = None # None = no limit def __init__(self) -> None: - self.queued_transactions = dict() + self.queued_transactions = {} self.pending_undos = set() def sort_transactions(self) -> None: - for file_path, file_transactions in self.queued_transactions.items(): + for file_transactions in self.queued_transactions.values(): file_transactions.sort(key=Transaction._to_sort_key) def clear_transactions(self) -> None: @@ -65,7 +68,7 @@ def _format_transactions(self, transactions: list[Transaction]) -> str: return "\n".join([">" * 100 + f"\n[ID: {t.transaction_id}]: {t.diff_str()}" + "<" * 100 for t in transactions]) def get_transactions_str(self) -> str: - """Returns a human-readable string representation of the transactions""" + """Returns a human-readable string representation of the transactions.""" return "\n\n\n".join([f"{file_path}:\n{self._format_transactions(transactions)}" for file_path, transactions in self.queued_transactions.items()]) #################################################################################################################### @@ -73,7 +76,7 @@ def get_transactions_str(self) -> str: #################################################################################################################### def get_num_transactions(self) -> int: - """Returns total number of transactions created to date""" + """Returns total number of transactions created to date.""" return sum([len(transactions) for transactions in self.queued_transactions.values()]) def set_max_transactions(self, max_transactions: int | None = None) -> None: @@ -96,9 +99,8 @@ def reset_stopwatch(self, max_seconds: int | None = None) -> int: def is_time_exceeded(self) -> bool: if self.stopwatch_max_seconds is None: return False - else: - num_seconds = time.time() - self.stopwatch_start - return num_seconds > self.stopwatch_max_seconds + num_seconds = time.time() - self.stopwatch_start + return num_seconds > self.stopwatch_max_seconds #################################################################################################################### # Transaction Creation @@ -108,11 +110,11 @@ def add_file_add_transaction(self, filepath: Path) -> None: t = FileAddTransaction(filepath) self.add_transaction(t) - def add_file_rename_transaction(self, file: "File", new_filepath: str) -> None: + def add_file_rename_transaction(self, file: File, new_filepath: str) -> None: t = FileRenameTransaction(file, new_filepath) self.add_transaction(t) - def add_file_remove_transaction(self, file: "File") -> None: + def add_file_remove_transaction(self, file: File) -> None: t = FileRemoveTransaction(file) self.add_transaction(t) @@ -134,11 +136,11 @@ def add_transaction(self, transaction: Transaction, dedupe: bool = True, solve_c self.check_limits() return True - def check_limits(self): + def check_limits(self) -> None: self.check_max_transactions() self.check_max_preview_time() - def check_max_transactions(self): + def check_max_transactions(self) -> None: # =====[ Max transactions ]===== # max_transactions is set so that long-running codemods terminate early so we can quickly surface a subset # of the results to the user. This may result in errors that do not get covered. @@ -147,7 +149,7 @@ def check_max_transactions(self): msg = f"Max transactions reached: {self.max_transactions}" raise MaxTransactionsExceeded(msg, threshold=self.max_transactions) - def check_max_preview_time(self): + def check_max_preview_time(self) -> None: # =====[ Max preview time ]===== # This is to prevent the preview from taking too long. We want to keep it at like ~5s in the frontend during debugging if self.is_time_exceeded(): @@ -160,7 +162,7 @@ def check_max_preview_time(self): #################################################################################################################### def to_commit(self, files: set[Path] | None = None) -> set[Path]: - """Get node ids of files to commit""" + """Get node ids of files to commit.""" if files is None: return set(self.queued_transactions.keys()) return files.intersection(self.queued_transactions) @@ -170,7 +172,7 @@ def commit(self, files: set[Path]) -> list[DiffLite]: Returns the list of diffs that were committed. """ if self._commiting: - logger.warn("Skipping commit, already committing") + logger.warning("Skipping commit, already committing") return [] self._commiting = True try: @@ -232,21 +234,20 @@ def break_down(to_break: EditTransaction) -> bool: if isinstance(completely_overlapping, RemoveTransaction): return None # If the overlapping transaction is an edit, raise an error - elif isinstance(completely_overlapping, EditTransaction): + if isinstance(completely_overlapping, EditTransaction): if break_down(completely_overlapping): return transaction - raise TransactionError() - else: - # If current transaction is deleted, remove all conflicting transactions - if isinstance(transaction, RemoveTransaction): - for t in conflicts: - file_queue.remove(t) - # If current transaction is edit, raise an error - elif isinstance(transaction, EditTransaction): - if break_down(transaction): - return None - raise TransactionError() + raise TransactionError + # If current transaction is deleted, remove all conflicting transactions + elif isinstance(transaction, RemoveTransaction): + for t in conflicts: + file_queue.remove(t) + # If current transaction is edit, raise an error + elif isinstance(transaction, EditTransaction): + if break_down(transaction): + return None + raise TransactionError # Add to priority queue and rebuild the queue return transaction @@ -283,14 +284,13 @@ def get_transactions_at_range(self, file_path: Path, start_byte: int, end_byte: if t.end_byte == end_byte: if transaction_order is None or t.transaction_order == transaction_order: matching_transactions.append(t) - elif combined and t.start_byte != t.end_byte: - if other := self.get_transactions_at_range(t.file_path, t.end_byte, end_byte, transaction_order, combined=combined): - return [t, *other] + elif combined and t.start_byte != t.end_byte and (other := self.get_transactions_at_range(t.file_path, t.end_byte, end_byte, transaction_order, combined=combined)): + return [t, *other] return matching_transactions def _get_conflicts(self, transaction: Transaction) -> list[Transaction]: - """Returns all transactions that overlap with the given transaction""" + """Returns all transactions that overlap with the given transaction.""" overlapping_transactions = [] queued_transactions = list(self.queued_transactions[transaction.file_path]) for t in queued_transactions: @@ -299,7 +299,7 @@ def _get_conflicts(self, transaction: Transaction) -> list[Transaction]: return overlapping_transactions def _get_overlapping_conflicts(self, transaction: Transaction) -> Transaction | None: - """Returns the transaction that completely overlaps with the given transaction""" + """Returns the transaction that completely overlaps with the given transaction.""" for t in self.queued_transactions[transaction.file_path]: if transaction.start_byte >= t.start_byte and transaction.end_byte <= t.end_byte: return t diff --git a/src/codegen/sdk/codebase/transactions.py b/src/codegen/sdk/codebase/transactions.py index 1abdf912c..d8ecd9687 100644 --- a/src/codegen/sdk/codebase/transactions.py +++ b/src/codegen/sdk/codebase/transactions.py @@ -4,7 +4,7 @@ from enum import IntEnum from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, NoReturn, Protocol, runtime_checkable from codegen.sdk.codebase.diff_lite import ChangeType, DiffLite @@ -78,7 +78,7 @@ def __eq__(self, other): def length(self): return self.end_byte - self.start_byte - def execute(self): + def execute(self) -> NoReturn: msg = "Transaction.execute() must be implemented by subclasses" raise NotImplementedError(msg) @@ -87,20 +87,20 @@ def get_diff(self) -> DiffLite: msg = "Transaction.get_diff() must be implemented by subclasses" raise NotImplementedError(msg) - def diff_str(self): + def diff_str(self) -> NoReturn: """Human-readable string representation of the change""" msg = "Transaction.diff_str() must be implemented by subclasses" raise NotImplementedError(msg) - def _to_sort_key(transaction: "Transaction"): + def _to_sort_key(self: "Transaction"): # Sort by: # 1. Descending start_byte # 2. Ascending transaction type # 3. Ascending priority # 4. Descending time of transaction= - priority = (transaction.priority,) if isinstance(transaction.priority, int) else transaction.priority + priority = (self.priority,) if isinstance(self.priority, int) else self.priority - return -transaction.start_byte, transaction.transaction_order.value, priority, -transaction.transaction_id + return -self.start_byte, self.transaction_order.value, priority, -self.transaction_id @cached_property def new_content(self) -> str | None: diff --git a/src/codegen/sdk/codebase/validation.py b/src/codegen/sdk/codebase/validation.py index ce3163528..4890c31cf 100644 --- a/src/codegen/sdk/codebase/validation.py +++ b/src/codegen/sdk/codebase/validation.py @@ -73,7 +73,7 @@ def post_sync_validation(codebase: CodebaseType) -> bool: ) -def log_or_throw(message, thread_message: str): +def log_or_throw(message, thread_message: str) -> None: hostname = socket.gethostname() logger.error(message) # logger.error(thread_message) diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 8aec1dc65..63d213b0a 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -178,7 +178,7 @@ def __str__(self) -> str: return f"" @noapidoc - def __repr__(self): + def __repr__(self) -> str: return str(self) def __rich_repr__(self) -> rich.repr.Result: @@ -846,7 +846,7 @@ def get_diff(self, base: str | None = None, stage_files: bool = False) -> str: return self._op.git_cli.git.diff(base, patch=True, full_index=True) @noapidoc - def clean_repo(self): + def clean_repo(self) -> None: """Cleaning a codebase repo by: 1. Deleting all branches except the checked out one 2. Deleting all remotes except origin @@ -859,12 +859,12 @@ def clean_repo(self): self._op.clean_branches() @noapidoc - def stash_changes(self): + def stash_changes(self) -> None: """Stash all changes in the codebase.""" self._op.stash_push() @noapidoc - def restore_stashed_changes(self): + def restore_stashed_changes(self) -> None: """Restore the most recent stash in the codebase.""" self._op.stash_pop() diff --git a/src/codegen/sdk/core/dataclasses/usage.py b/src/codegen/sdk/core/dataclasses/usage.py index 3cb9f618e..6b27b2f24 100644 --- a/src/codegen/sdk/core/dataclasses/usage.py +++ b/src/codegen/sdk/core/dataclasses/usage.py @@ -41,7 +41,7 @@ class Usage: @unique @apidoc class UsageType(IntFlag): - """Describes how a symbol is used elsewhere. Used in conjunction with get_usages + """Describes how a symbol is used elsewhere. Used in conjunction with get_usages. Attributes: DIRECT: Direct imports and usages within the same file diff --git a/src/codegen/sdk/core/detached_symbols/function_call.py b/src/codegen/sdk/core/detached_symbols/function_call.py index 7699343e9..4363db2ad 100644 --- a/src/codegen/sdk/core/detached_symbols/function_call.py +++ b/src/codegen/sdk/core/detached_symbols/function_call.py @@ -42,7 +42,7 @@ class FunctionCall(Expression[Parent], HasName, Resolvable, Generic[Parent]): """Abstract representation of a function invocation, e.g. in Python: ``` def f(): - g() # FunctionCall + g() # FunctionCall ``` """ @@ -588,12 +588,12 @@ def descendant_symbols(self) -> list[Importable]: @noapidoc @writer - def rename_if_matching(self, old: str, new: str): + def rename_if_matching(self, old: str, new: str) -> None: if name := self.get_name(): name.rename_if_matching(old, new) @noapidoc - def register_api_call(self, url: str): + def register_api_call(self, url: str) -> None: assert url, self self.G.global_context.multigraph.usages[url].append(self) diff --git a/src/codegen/sdk/core/directory.py b/src/codegen/sdk/core/directory.py index 0254b6af8..23fab13b7 100644 --- a/src/codegen/sdk/core/directory.py +++ b/src/codegen/sdk/core/directory.py @@ -47,7 +47,7 @@ class Directory(Generic[TFile, TSymbol, TImportStatement, TGlobalVar, TClass, TF parent: Self | None items: dict[str, TFile | Self] - def __init__(self, path: Path, dirpath: str, parent: Self | None): + def __init__(self, path: Path, dirpath: str, parent: Self | None) -> None: self.path = path self.dirpath = dirpath self.parent = parent @@ -95,7 +95,7 @@ def files(self) -> list[TFile]: """Get a recursive list of all files in the directory and its subdirectories.""" files = [] - def _get_files(directory: Directory): + def _get_files(directory: Directory) -> None: for item in directory.items.values(): if isinstance(item, Directory): _get_files(item) @@ -110,7 +110,7 @@ def subdirectories(self) -> list[Self]: """Get a recursive list of all subdirectories in the directory and its subdirectories.""" subdirectories = [] - def _get_subdirectories(directory: Directory): + def _get_subdirectories(directory: Directory) -> None: for item in directory.items.values(): if isinstance(item, Directory): subdirectories.append(item) diff --git a/src/codegen/sdk/core/expressions/await_expression.py b/src/codegen/sdk/core/expressions/await_expression.py index 15e2a9083..6f0bc0284 100644 --- a/src/codegen/sdk/core/expressions/await_expression.py +++ b/src/codegen/sdk/core/expressions/await_expression.py @@ -17,7 +17,7 @@ class AwaitExpression(Expression[Parent], HasValue, IWrapper, Generic[Parent]): """An awaited expression, only found in asynchronous contexts, e.g. await(foo(bar))""" - def __init__(self, ts_node, file_node_id, G, parent: Parent): + def __init__(self, ts_node, file_node_id, G, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) value_node = self.ts_node.named_children[0] self._value_node = self.G.parser.parse_expression(value_node, self.file_node_id, self.G, parent) if value_node else None diff --git a/src/codegen/sdk/core/expressions/boolean.py b/src/codegen/sdk/core/expressions/boolean.py index 6d8a4d9aa..843e58a0b 100644 --- a/src/codegen/sdk/core/expressions/boolean.py +++ b/src/codegen/sdk/core/expressions/boolean.py @@ -17,7 +17,7 @@ class Boolean(Expression[Parent], Builtin, Generic[Parent]): True, False """ - def __bool__(self): + def __bool__(self) -> bool: return self.ts_node.type == "true" @noapidoc diff --git a/src/codegen/sdk/core/expressions/chained_attribute.py b/src/codegen/sdk/core/expressions/chained_attribute.py index aec915dcd..7b5af3f68 100644 --- a/src/codegen/sdk/core/expressions/chained_attribute.py +++ b/src/codegen/sdk/core/expressions/chained_attribute.py @@ -34,7 +34,7 @@ class ChainedAttribute(Expression[Parent], Resolvable, Generic[Object, Attribute _object: Object _attribute: Attribute - def __init__(self, ts_node, file_node_id, G, parent: Parent, object: TSNode, attribute: TSNode): + def __init__(self, ts_node, file_node_id, G, parent: Parent, object: TSNode, attribute: TSNode) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) self._object = self._parse_expression(object, default=Name) if self.G.parser._should_log: @@ -130,6 +130,6 @@ def descendant_symbols(self) -> list["Importable"]: @noapidoc @writer - def rename_if_matching(self, old: str, new: str): + def rename_if_matching(self, old: str, new: str) -> None: if self.attribute.source == old: self.attribute.edit(new) diff --git a/src/codegen/sdk/core/expressions/generic_type.py b/src/codegen/sdk/core/expressions/generic_type.py index 34f0d56f8..115c8143a 100644 --- a/src/codegen/sdk/core/expressions/generic_type.py +++ b/src/codegen/sdk/core/expressions/generic_type.py @@ -31,7 +31,7 @@ class GenericType(NamedType[Parent], Generic[TType, Parent]): _parameters: Collection[TType, Self] - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self._parameters = self._get_parameters() @@ -54,7 +54,7 @@ def _get_parameters(self) -> Collection[TType, Self]: @noapidoc @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: super()._compute_dependencies(usage_type, dest) for param in self._parameters: param._compute_dependencies(UsageKind.GENERIC, dest) diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index bb7ecbf19..efb646747 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -45,6 +45,6 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | @noapidoc @writer - def rename_if_matching(self, old: str, new: str): + def rename_if_matching(self, old: str, new: str) -> None: if self.source == old: self.edit(new) diff --git a/src/codegen/sdk/core/expressions/named_type.py b/src/codegen/sdk/core/expressions/named_type.py index 7628a1291..20c4479cf 100644 --- a/src/codegen/sdk/core/expressions/named_type.py +++ b/src/codegen/sdk/core/expressions/named_type.py @@ -26,7 +26,7 @@ class NamedType(Resolvable, Type[Parent], HasName, Generic[Parent]): """An abstract representation of a named type.""" - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self._name_node = self._parse_expression(self._get_name_node(), default=Name) @@ -56,7 +56,7 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: @noapidoc @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: if isinstance(self.get_name(), String): # TODO: string annotations self._log_parse("String type annotations are not currently supported") @@ -71,5 +71,5 @@ def descendant_symbols(self) -> list["Importable"]: @noapidoc @writer - def rename_if_matching(self, old: str, new: str): + def rename_if_matching(self, old: str, new: str) -> None: self.get_name().rename_if_matching(old, new) diff --git a/src/codegen/sdk/core/expressions/none_type.py b/src/codegen/sdk/core/expressions/none_type.py index 8a5956090..305f193b5 100644 --- a/src/codegen/sdk/core/expressions/none_type.py +++ b/src/codegen/sdk/core/expressions/none_type.py @@ -19,7 +19,7 @@ class NoneType(Type[Parent], Generic[Parent]): """Represents a None or Null object.""" @noapidoc - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: pass @reader diff --git a/src/codegen/sdk/core/expressions/parenthesized_expression.py b/src/codegen/sdk/core/expressions/parenthesized_expression.py index e4452a7c0..d814679a7 100644 --- a/src/codegen/sdk/core/expressions/parenthesized_expression.py +++ b/src/codegen/sdk/core/expressions/parenthesized_expression.py @@ -25,7 +25,7 @@ class ParenthesizedExpression(Unwrappable[Parent], HasValue, IWrapper, Generic[P ``` """ - def __init__(self, ts_node, file_node_id, G, parent: Parent): + def __init__(self, ts_node, file_node_id, G, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) value_node = self.ts_node.named_children[0] self._value_node = self.G.parser.parse_expression(value_node, self.file_node_id, self.G, self) if value_node else None diff --git a/src/codegen/sdk/core/expressions/placeholder_type.py b/src/codegen/sdk/core/expressions/placeholder_type.py index 5ce966f21..e7726b031 100644 --- a/src/codegen/sdk/core/expressions/placeholder_type.py +++ b/src/codegen/sdk/core/expressions/placeholder_type.py @@ -22,7 +22,7 @@ class PlaceholderType(Type[Parent], Generic[TType, Parent]): @noapidoc @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: self._add_all_identifier_usages(usage_type, dest=dest) @reader diff --git a/src/codegen/sdk/core/expressions/string.py b/src/codegen/sdk/core/expressions/string.py index 4101a3bf8..1ecf0d49f 100644 --- a/src/codegen/sdk/core/expressions/string.py +++ b/src/codegen/sdk/core/expressions/string.py @@ -45,7 +45,7 @@ def __eq__(self, other: object) -> bool: return True return super().__eq__(other) - def __str__(self): + def __str__(self) -> str: return self.content def __hash__(self): diff --git a/src/codegen/sdk/core/expressions/subscript_expression.py b/src/codegen/sdk/core/expressions/subscript_expression.py index 5de07b529..cf9fe143c 100644 --- a/src/codegen/sdk/core/expressions/subscript_expression.py +++ b/src/codegen/sdk/core/expressions/subscript_expression.py @@ -31,7 +31,7 @@ class SubscriptExpression(Expression[Parent], Resolvable[Parent], Generic[Object object: Object indices: list[Index] - def __init__(self, ts_node, file_node_id, G, parent: Parent): + def __init__(self, ts_node, file_node_id, G, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) self.object = self._parse_expression(self.ts_node.children[0], default=Name) self.indices = self.children[1:] diff --git a/src/codegen/sdk/core/expressions/tuple_type.py b/src/codegen/sdk/core/expressions/tuple_type.py index 315c9ed09..2994beae2 100644 --- a/src/codegen/sdk/core/expressions/tuple_type.py +++ b/src/codegen/sdk/core/expressions/tuple_type.py @@ -26,7 +26,7 @@ class TupleType(Collection[Type, Parent], Type[Parent], Generic[TType, Parent]): For example `[number, number]`. """ - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent, delimiter=" |") elements = list(self._get_types(ts_node)) self._init_children(elements) diff --git a/src/codegen/sdk/core/expressions/union_type.py b/src/codegen/sdk/core/expressions/union_type.py index 877a6ef9d..342059492 100644 --- a/src/codegen/sdk/core/expressions/union_type.py +++ b/src/codegen/sdk/core/expressions/union_type.py @@ -26,7 +26,7 @@ class UnionType(Collection[Type, Parent], Type[Parent], Generic[TType, Parent]): For example `str | None` or `string | number`. """ - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent, delimiter=" |") elements = list(self._get_types(ts_node)) self._init_children(elements) diff --git a/src/codegen/sdk/core/expressions/unpack.py b/src/codegen/sdk/core/expressions/unpack.py index 6a2c8b814..34f39b7be 100644 --- a/src/codegen/sdk/core/expressions/unpack.py +++ b/src/codegen/sdk/core/expressions/unpack.py @@ -22,7 +22,7 @@ class Unpack(Unwrappable[Parent], HasValue, IWrapper, Generic[Parent]): ``` """ - def __init__(self, ts_node, file_node_id, G, parent: Parent): + def __init__(self, ts_node, file_node_id, G, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) self._value_node = self.children[0] diff --git a/src/codegen/sdk/core/expressions/value.py b/src/codegen/sdk/core/expressions/value.py index b6116dd7e..9ba29e9e8 100644 --- a/src/codegen/sdk/core/expressions/value.py +++ b/src/codegen/sdk/core/expressions/value.py @@ -23,12 +23,12 @@ class Value(Expression[Parent], Generic[Parent]): See also HasValue. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.G.parser.log_unparsed(self.ts_node) @noapidoc @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None): + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: for node in self.children: node._compute_dependencies(usage_type, dest=dest) diff --git a/src/codegen/sdk/core/external/dependency_manager.py b/src/codegen/sdk/core/external/dependency_manager.py index 85e250c34..d9f87de82 100644 --- a/src/codegen/sdk/core/external/dependency_manager.py +++ b/src/codegen/sdk/core/external/dependency_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import abstractmethod from typing import TYPE_CHECKING @@ -27,12 +29,11 @@ def remove_dependencies(self): pass -def get_dependency_manager(language: ProgrammingLanguage, codebase_graph: "CodebaseGraph", enabled: bool = False) -> DependencyManager | None: +def get_dependency_manager(language: ProgrammingLanguage, codebase_graph: CodebaseGraph, enabled: bool = False) -> DependencyManager | None: from codegen.sdk.typescript.external.dependency_manager import TypescriptDependencyManager ts_enabled = enabled or codebase_graph.config.feature_flags.ts_dependency_manager - if language == ProgrammingLanguage.TYPESCRIPT: - if ts_enabled: - return TypescriptDependencyManager(repo_path=codebase_graph.repo_path, base_path=codebase_graph.projects[0].base_path) + if language == ProgrammingLanguage.TYPESCRIPT and ts_enabled: + return TypescriptDependencyManager(repo_path=codebase_graph.repo_path, base_path=codebase_graph.projects[0].base_path) return None diff --git a/src/codegen/sdk/core/external/external_process.py b/src/codegen/sdk/core/external/external_process.py index 2b972f207..cbc9c4b04 100644 --- a/src/codegen/sdk/core/external/external_process.py +++ b/src/codegen/sdk/core/external/external_process.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import threading @@ -26,14 +28,14 @@ class ExternalProcess(ABC): is_ready: bool _error: BaseException | None - def __init__(self, repo_path: str, base_path: str | None = None): + def __init__(self, repo_path: str, base_path: str | None = None) -> None: self.repo_path: str = repo_path self.base_path: str | None = base_path self.full_path = os.path.join(repo_path, base_path) if base_path else repo_path self.is_ready: bool = False self._error: BaseException | None = None - def start(self, async_start: bool = False): + def start(self, async_start: bool = False) -> None: if async_start: # Create a new thread to start the engine thread = threading.Thread(target=self._start) @@ -45,7 +47,7 @@ def start(self, async_start: bool = False): def _start(self): pass - def reparse(self, async_start: bool = False): + def reparse(self, async_start: bool = False) -> None: # Reparse logic is handled by re-running start() self.is_ready = False self.start(async_start=async_start) @@ -56,7 +58,7 @@ def ready(self) -> bool: def error(self) -> BaseException | None: return self._error - def wait_until_ready(self, ignore_error: bool = False): + def wait_until_ready(self, ignore_error: bool = False) -> None: logger.info(f"Waiting for {self.__class__.__name__} to be ready...") # Wait for 3 minutes first start_time = time.time() diff --git a/src/codegen/sdk/core/external/language_engine.py b/src/codegen/sdk/core/external/language_engine.py index 8e2706460..1a391043c 100644 --- a/src/codegen/sdk/core/external/language_engine.py +++ b/src/codegen/sdk/core/external/language_engine.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import abstractmethod from typing import TYPE_CHECKING @@ -17,11 +19,11 @@ class LanguageEngine(ExternalProcess): """ @abstractmethod - def get_return_type(self, node: "Editable") -> str | None: + def get_return_type(self, node: Editable) -> str | None: pass -def get_language_engine(language: ProgrammingLanguage, codebase_graph: "CodebaseGraph", use_ts: bool = False, use_v8: bool = False) -> LanguageEngine | None: +def get_language_engine(language: ProgrammingLanguage, codebase_graph: CodebaseGraph, use_ts: bool = False, use_v8: bool = False) -> LanguageEngine | None: from codegen.sdk.typescript.external.ts_analyzer_engine import NodeTypescriptEngine, V8TypescriptEngine use_ts = use_ts or codebase_graph.config.feature_flags.ts_language_engine @@ -30,7 +32,7 @@ def get_language_engine(language: ProgrammingLanguage, codebase_graph: "Codebase if use_ts and use_v8: # Enables with both ts_language_engine and v8_ts_engine feature flags are on return V8TypescriptEngine(repo_path=codebase_graph.repo_path, base_path=codebase_graph.projects[0].base_path, dependency_manager=codebase_graph.dependency_manager) - elif use_ts: + if use_ts: # Enabled with only ts_language_engine feature flag is on return NodeTypescriptEngine(repo_path=codebase_graph.repo_path, base_path=codebase_graph.projects[0].base_path, dependency_manager=codebase_graph.dependency_manager) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 259bd9961..f989edcfc 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -503,7 +503,7 @@ def _compute_dependencies(self, *args, **kwargs) -> None: self.code_block._compute_dependencies() @noapidoc - def invalidate(self): + def invalidate(self) -> None: self.__dict__.pop("valid_symbol_names", None) self.__dict__.pop("valid_import_names", None) for imp in self.imports: @@ -812,11 +812,8 @@ def get_node_by_name(self, name: str) -> Symbol | TImport | None: def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImport]]: """Returns a dict mapping name => Symbol (or import) in this file.""" valid_symbol_names = {} - for s in self.symbols: - valid_symbol_names[s.full_name] = s - for imp in self.imports: - for name, dest in imp.names: - valid_symbol_names[name] = dest + valid_symbol_names.update({s.full_name: s for s in self.symbols}) + valid_symbol_names.update({name: dest for imp in self.imports for name, dest in imp.names}) return valid_symbol_names @noapidoc diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index d40cd1fa4..a8ccd2fd1 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -411,5 +411,5 @@ def descendant_symbols(self) -> Sequence[Importable]: return symbols @noapidoc - def register_api(self, url: str): + def register_api(self, url: str) -> None: self.G.global_context.multigraph.api_definitions[url] = self diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index 432ec2ab5..1c56da7ff 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -403,6 +403,7 @@ def is_dynamic(self) -> bool: def my_function(): import foo # Dynamic - only imported when function runs + if condition: from bar import baz # Dynamic - only imported if condition is True @@ -671,7 +672,7 @@ class WildcardImport(Chainable, Generic[TImport]): imp: TImport symbol: Importable - def __init__(self, imp: TImport, symbol: Importable): + def __init__(self, imp: TImport, symbol: Importable) -> None: self.imp = imp self.symbol = symbol self.ts_node = imp.ts_node diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 87378165c..e14385b4d 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -712,7 +712,7 @@ def _remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: boo removed_end_byte = max(parent.end_byte, removed_end_byte) parent = parent.parent - def should_keep(node: TSNode): + def should_keep(node: TSNode) -> bool: if node.type == "comment": # Remove comments on the same rows as the deleted node if node.end_point[0] <= self.end_point[0] and node.start_byte > removed_start_byte: @@ -970,7 +970,7 @@ def add_all_identifier_usages_for_child_node(self, usage_type: UsageKind, child: return self._add_symbol_usages(identifiers, usage_type, dest) @noapidoc - def _log_parse(self, msg: str, *args, **kwargs): + def _log_parse(self, msg: str, *args, **kwargs) -> None: self.G.parser.log(msg, *args, **kwargs) @property @@ -1123,10 +1123,7 @@ def parent_class(self) -> Class | None: def _get_ast_children(self) -> list[tuple[str | None, AST]]: children = [] - names = {} - for name, val in self._list_members(include_methods=True).items(): - if isinstance(val, Editable): - names[val] = name + names = {val: name for name, val in self._list_members(include_methods=True).items() if isinstance(val, Editable)} for child in self.file._range_index.get_children(self): if self.G.config.feature_flags.debug: assert child != self, child diff --git a/src/codegen/sdk/core/interfaces/has_name.py b/src/codegen/sdk/core/interfaces/has_name.py index e8c09be5a..73d496f09 100644 --- a/src/codegen/sdk/core/interfaces/has_name.py +++ b/src/codegen/sdk/core/interfaces/has_name.py @@ -90,7 +90,7 @@ def rename(self, name: str) -> None: @noapidoc @commiter - def _add_name_usage(self, usage_type: UsageKind): + def _add_name_usage(self, usage_type: UsageKind) -> None: if name := self.get_name(): if resolved := name.resolved_symbol(): self._add_symbol_usages(usage_type, [resolved]) diff --git a/src/codegen/sdk/core/parser.py b/src/codegen/sdk/core/parser.py index 37be5d4d0..6bb69f181 100644 --- a/src/codegen/sdk/core/parser.py +++ b/src/codegen/sdk/core/parser.py @@ -282,11 +282,11 @@ def parse_py_statements(self, node: TSNode, file_node_id: NodeId, G: CodebaseGra statements[-1].nested_code_blocks return statements - def report(self): + def report(self) -> None: if self._uncovered_nodes: self._console.print(f"Encountered unimplemented nodes {self._uncovered_nodes}") - def log(self, message: str, *args): + def log(self, message: str, *args) -> None: if self._should_log: try: self._console.log(message % args) diff --git a/src/codegen/sdk/core/plugins/axios.py b/src/codegen/sdk/core/plugins/axios.py index 22a4de747..12bd15a00 100644 --- a/src/codegen/sdk/core/plugins/axios.py +++ b/src/codegen/sdk/core/plugins/axios.py @@ -19,7 +19,7 @@ class AxiosApiFinder(Plugin): language: ProgrammingLanguage = ProgrammingLanguage.TYPESCRIPT - def execute(self, codebase: "TSCodebaseType"): + def execute(self, codebase: "TSCodebaseType") -> None: logger.info("Scanning for Axios API calls") api_calls = 0 diff --git a/src/codegen/sdk/core/plugins/flask.py b/src/codegen/sdk/core/plugins/flask.py index 1a399e6fd..eb5295b13 100644 --- a/src/codegen/sdk/core/plugins/flask.py +++ b/src/codegen/sdk/core/plugins/flask.py @@ -30,7 +30,7 @@ def extract_methods(decorator): class FlaskApiFinder(Plugin): language: ProgrammingLanguage = ProgrammingLanguage.PYTHON - def execute(self, codebase: "PyCodebaseType"): + def execute(self, codebase: "PyCodebaseType") -> None: logger.info("Scanning for flask endpoints") endpoints = 0 for func in codebase.functions: diff --git a/src/codegen/sdk/core/plugins/modal.py b/src/codegen/sdk/core/plugins/modal.py index d6edc2475..5e74f7e9b 100644 --- a/src/codegen/sdk/core/plugins/modal.py +++ b/src/codegen/sdk/core/plugins/modal.py @@ -13,7 +13,7 @@ class ModalApiFinder(Plugin): language: ProgrammingLanguage = ProgrammingLanguage.PYTHON - def execute(self, codebase: "PyCodebaseType"): + def execute(self, codebase: "PyCodebaseType") -> None: logger.info("Scanning for modal endpoints") endpoints = 0 for func in codebase.functions: diff --git a/src/codegen/sdk/core/plugins/plugin.py b/src/codegen/sdk/core/plugins/plugin.py index 79d9f6ed3..f48e34573 100644 --- a/src/codegen/sdk/core/plugins/plugin.py +++ b/src/codegen/sdk/core/plugins/plugin.py @@ -13,5 +13,5 @@ class Plugin(ABC): @abstractmethod def execute(self, codebase: "Codebase"): ... - def register_api(self, method: str, label: str, node: Editable): + def register_api(self, method: str, label: str, node: Editable) -> None: pass diff --git a/src/codegen/sdk/core/statements/comment.py b/src/codegen/sdk/core/statements/comment.py index 46c5f8c15..35744e0b7 100644 --- a/src/codegen/sdk/core/statements/comment.py +++ b/src/codegen/sdk/core/statements/comment.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, NoReturn, TypeVar from codegen.sdk.core.autocommit import commiter, reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind @@ -119,7 +119,7 @@ def _parse_comment(self) -> str: @noapidoc @commiter - def _unparse_comment(self, new_src: str): + def _unparse_comment(self, new_src: str) -> NoReturn: """Unparses cleaned text content into a comment block.""" msg = "This method should be implemented by the subclass" raise NotImplementedError(msg) diff --git a/src/codegen/sdk/core/statements/export_statement.py b/src/codegen/sdk/core/statements/export_statement.py index 67d4fbbb3..f1a90c698 100644 --- a/src/codegen/sdk/core/statements/export_statement.py +++ b/src/codegen/sdk/core/statements/export_statement.py @@ -35,7 +35,7 @@ class ExportStatement(Statement["TSCodeBlock"], Generic[TExport]): exports: Collection[TExport, Self] statement_type = StatementType.EXPORT_STATEMENT - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: TSCodeBlock, pos: int): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: TSCodeBlock, pos: int) -> None: super().__init__(ts_node, parent.file_node_id, parent.G, parent, pos) export_node = self.ts_node if node := self.child_by_field_types(["export_clause", "export_statement"]): diff --git a/src/codegen/sdk/core/statements/expression_statement.py b/src/codegen/sdk/core/statements/expression_statement.py index 142fdf1f7..259dbd7bd 100644 --- a/src/codegen/sdk/core/statements/expression_statement.py +++ b/src/codegen/sdk/core/statements/expression_statement.py @@ -63,7 +63,7 @@ def function_calls(self) -> list[FunctionCall]: @noapidoc @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None): + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: if self._value_node: self.resolve()._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/core/statements/statement.py b/src/codegen/sdk/core/statements/statement.py index 7d4175c6e..c9cf6e271 100644 --- a/src/codegen/sdk/core/statements/statement.py +++ b/src/codegen/sdk/core/statements/statement.py @@ -138,5 +138,5 @@ def _get_indent(self) -> int: @noapidoc @commiter - def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None): + def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = None) -> None: self._add_all_identifier_usages(usage_type, dest=dest) diff --git a/src/codegen/sdk/core/symbol_group.py b/src/codegen/sdk/core/symbol_group.py index c8340e012..2e4ee9869 100644 --- a/src/codegen/sdk/core/symbol_group.py +++ b/src/codegen/sdk/core/symbol_group.py @@ -37,7 +37,7 @@ def __init__(self, file_node_id: NodeId, G: CodebaseGraph, parent: Parent, node: node = children[0].ts_node super().__init__(node, file_node_id, G, parent) - def _init_children(self): ... + def _init_children(self) -> None: ... @repr_func # HACK def __hash__(self): diff --git a/src/codegen/sdk/core/symbol_groups/collection.py b/src/codegen/sdk/core/symbol_groups/collection.py index 5a349138b..565664251 100644 --- a/src/codegen/sdk/core/symbol_groups/collection.py +++ b/src/codegen/sdk/core/symbol_groups/collection.py @@ -47,7 +47,7 @@ def __init__(self, node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", paren if children is not None: self._init_children(children) - def _init_children(self, symbols: list[Child]): + def _init_children(self, symbols: list[Child]) -> None: """Call this after setting self._symbols.""" if self.ts_node.start_point[0] != self.ts_node.end_point[0] and symbols: # This is a multiline collection. @@ -264,11 +264,11 @@ def index(self, value: Child, start: int = 0, stop: int | None = None) -> int: return -1 @noapidoc - def _mark_dirty(self): + def _mark_dirty(self) -> None: self.transaction_manager.pending_undos.add(self.reset) @noapidoc - def reset(self): + def reset(self) -> None: self._pending_removes = 0 self._elements = len(self._original_children) self._symbols = self._original_children.copy() diff --git a/src/codegen/sdk/core/symbol_groups/expression_group.py b/src/codegen/sdk/core/symbol_groups/expression_group.py index 82f6da266..5a73f6e67 100644 --- a/src/codegen/sdk/core/symbol_groups/expression_group.py +++ b/src/codegen/sdk/core/symbol_groups/expression_group.py @@ -59,8 +59,4 @@ def function_calls(self) -> list[FunctionCall]: list[FunctionCall]: A list of all function calls found in the expressions of this group. """ - fcalls = [] - for expr in self.expressions: - for call in expr.function_calls: - fcalls.append(call) - return fcalls + return [call for expr in self.expressions for call in expr.function_calls] diff --git a/src/codegen/sdk/extensions/autocommit.pyi b/src/codegen/sdk/extensions/autocommit.pyi index fea4fad06..45b10dccd 100644 --- a/src/codegen/sdk/extensions/autocommit.pyi +++ b/src/codegen/sdk/extensions/autocommit.pyi @@ -14,8 +14,6 @@ def reader(wrapped: Callable[P, T]) -> Callable[P, T]: ... def reader(wrapped: None = None, *, cache: bool | None = ...) -> Callable[[Callable[P, T]], Callable[P, T]]: ... class AutoCommitMixin: - """Support for autocommit""" - autocommit_cache: dict[str, Any] removed: bool def __init__(self, G: CodebaseGraph) -> None: ... diff --git a/src/codegen/sdk/extensions/utils.pyi b/src/codegen/sdk/extensions/utils.pyi index 67e9d92ea..6a0080850 100644 --- a/src/codegen/sdk/extensions/utils.pyi +++ b/src/codegen/sdk/extensions/utils.pyi @@ -3,9 +3,7 @@ from functools import cached_property from tree_sitter import Node as TSNode -def get_all_identifiers(node: TSNode) -> list[TSNode]: - """Get all the identifiers in a tree-sitter node. Recursive implementation""" - +def get_all_identifiers(node: TSNode) -> list[TSNode]: ... def iter_all_descendants(node: TSNode, type_names: Iterable[str] | str, max_depth: int | None = None, nested: bool = True) -> Generator[TSNode, None, None]: ... def find_all_descendants( node: TSNode, @@ -13,9 +11,7 @@ def find_all_descendants( max_depth: int | None = None, nested: bool = True, ) -> list[TSNode]: ... -def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]: - """Returns a list of tuples of the start and end nodes of each line in the node""" - +def find_line_start_and_end_nodes(node: TSNode) -> list[tuple[TSNode, TSNode]]: ... def find_first_descendant(node: TSNode, type_names: list[str], max_depth: int | None = None) -> TSNode | None: ... cached_property = cached_property diff --git a/src/codegen/sdk/output/ast.py b/src/codegen/sdk/output/ast.py index da23005a7..23eee8742 100644 --- a/src/codegen/sdk/output/ast.py +++ b/src/codegen/sdk/output/ast.py @@ -1,9 +1,12 @@ -from typing import Self +from __future__ import annotations + +from typing import TYPE_CHECKING, Self from openai import BaseModel from pydantic import ConfigDict -from codegen.sdk.codebase.span import Span +if TYPE_CHECKING: + from codegen.sdk.codebase.span import Span class AST(BaseModel): diff --git a/src/codegen/sdk/output/utils.py b/src/codegen/sdk/output/utils.py index 5dfe8a21d..4b31e8178 100644 --- a/src/codegen/sdk/output/utils.py +++ b/src/codegen/sdk/output/utils.py @@ -1,17 +1,23 @@ +from __future__ import annotations + import json import sys from decimal import Decimal -from os import PathLike from pathlib import Path +from typing import TYPE_CHECKING from rich.console import Console, RenderResult from rich.syntax import Syntax from rich.text import Text -from tree_sitter import Node as TSNode -from tree_sitter import Point from codegen.sdk.output.constants import MAX_EDITABLE_LINES +if TYPE_CHECKING: + from os import PathLike + + from tree_sitter import Node as TSNode + from tree_sitter import Point + def style_editable(ts_node: TSNode, filepath: PathLike, file_node: TSNode) -> RenderResult: start_line = ts_node.start_point[0] + 1 # 1 based @@ -39,7 +45,7 @@ def _stylize_range(end_col, end_line, file_node, filepath, start_col, start_line return syntax -def stylize_error(path: PathLike, start: tuple[int, int] | Point, end: tuple[int, int] | Point, file_node: TSNode, content: str, message: str): +def stylize_error(path: PathLike, start: tuple[int, int] | Point, end: tuple[int, int] | Point, file_node: TSNode, content: str, message: str) -> None: Path(path).write_text(content) source = _stylize_range(end[1], end[0] + 1, file_node, path, start[1], start[0] + 1) console = Console(file=sys.stderr) @@ -61,7 +67,7 @@ def default(self, obj): if isinstance(obj, Decimal): return f"{obj:.10f}" if isinstance(obj, set): - return sorted(list(obj)) + return sorted(obj) if hasattr(obj, "__dict__"): return {key: self.default(value) for key, value in obj.__dict__.items()} return super().default(obj) @@ -71,14 +77,12 @@ def deterministic_json_dumps(data, **kwargs): def sort_dict(item): if isinstance(item, dict): return {key: sort_dict(value) for key, value in sorted(item.items())} - elif isinstance(item, list): + if isinstance(item, list): if len(item) > 0 and isinstance(item[0], dict): # Sort list of dictionaries based on all keys return sorted([sort_dict(i) for i in item], key=lambda x: json.dumps(x, sort_keys=True)) - else: - return [sort_dict(i) for i in item] - else: - return item + return [sort_dict(i) for i in item] + return item sorted_data = sort_dict(data) return json.dumps(sorted_data, cls=DeterministicJSONEncoder, **kwargs) diff --git a/src/codegen/sdk/python/detached_symbols/code_block.py b/src/codegen/sdk/python/detached_symbols/code_block.py index 4f09d2ad9..9238b4611 100644 --- a/src/codegen/sdk/python/detached_symbols/code_block.py +++ b/src/codegen/sdk/python/detached_symbols/code_block.py @@ -11,6 +11,7 @@ from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: + from codegen.sdk.core.interfaces.editable import Editable from codegen.sdk.python.assignment import PyAssignment from codegen.sdk.python.interfaces.has_block import PyHasBlock from codegen.sdk.python.statements.with_statement import WithStatement @@ -27,7 +28,7 @@ class PyCodeBlock(CodeBlock[Parent, "PyAssignment"], Generic[Parent]): @reader def _parse_statements(self) -> MultiLineCollection[Statement, Self]: statements: list[Statement] = self.G.parser.parse_py_statements(self.ts_node, self.file_node_id, self.G, self) - collection = MultiLineCollection( + return MultiLineCollection( children=statements, file_node_id=self.file_node_id, G=self.G, @@ -37,7 +38,6 @@ def _parse_statements(self) -> MultiLineCollection[Statement, Self]: leading_delimiter="", start_byte=self.start_byte - self.start_point[1], ) - return collection @property @reader @@ -65,13 +65,12 @@ def get_with_statements(self, level: int) -> list[WithStatement]: """ return [x for x in self.with_statements if x.parent.level == level] - def _smart_remove(self, child, *args, **kwargs) -> bool: + def _smart_remove(self, child: Editable, *args, **kwargs) -> bool: if len(self.statements) <= 1 and not isinstance(child, ImportStatement): if isinstance(self.parent, BlockStatement): self.parent.remove(*args, **kwargs) return True - else: - self.remove_byte_range(self.start_byte, self.end_byte) - self.parent.insert_after("pass", newline=False) - return True + self.remove_byte_range(self.start_byte, self.end_byte) + self.parent.insert_after("pass", newline=False) + return True return False diff --git a/src/codegen/sdk/python/detached_symbols/decorator.py b/src/codegen/sdk/python/detached_symbols/decorator.py index e56902edc..79c87680a 100644 --- a/src/codegen/sdk/python/detached_symbols/decorator.py +++ b/src/codegen/sdk/python/detached_symbols/decorator.py @@ -25,19 +25,13 @@ def _get_name_node(self) -> TSNode: for child in self.ts_node.children: # =====[ Identifier ]===== # Just `@dataclass` etc. - if child.type == "identifier": - return child - - # =====[ Attribute ]===== - # e.g. `@a.b` - elif child.type == "attribute": + if child.type in ("identifier", "attribute"): return child # =====[ Call ]===== # e.g. `@a.b()` - elif child.type == "call": - func = child.child_by_field_name("function") - return func + if child.type == "call": + return child.child_by_field_name("function") msg = f"Could not find decorator name within {self.source}" raise ValueError(msg) diff --git a/src/codegen/sdk/python/detached_symbols/parameter.py b/src/codegen/sdk/python/detached_symbols/parameter.py index 7eccd7aae..fde80c12e 100644 --- a/src/codegen/sdk/python/detached_symbols/parameter.py +++ b/src/codegen/sdk/python/detached_symbols/parameter.py @@ -27,9 +27,7 @@ def is_optional(self) -> bool: Returns: bool: True if the parameter is optional, False otherwise. """ - return ( - self.ts_node.type == "default_parameter" or self.ts_node.type == "typed_default_parameter" or self.ts_node.type == "list_splat_pattern" or self.ts_node.type == "dictionary_splat_pattern" - ) + return self.ts_node.type in ("default_parameter", "typed_default_parameter", "list_splat_pattern", "dictionary_splat_pattern") @property @reader @@ -41,7 +39,7 @@ def is_variadic(self) -> bool: Returns: bool: True if the parameter is variadic (uses * or ** syntax), False otherwise. """ - return self.ts_node.type == "list_splat_pattern" or self.ts_node.type == "dictionary_splat_pattern" + return self.ts_node.type in ("list_splat_pattern", "dictionary_splat_pattern") @deprecated("Use `type.edit` instead") @writer diff --git a/src/codegen/sdk/python/expressions/chained_attribute.py b/src/codegen/sdk/python/expressions/chained_attribute.py index b540ddee0..4df28f1a6 100644 --- a/src/codegen/sdk/python/expressions/chained_attribute.py +++ b/src/codegen/sdk/python/expressions/chained_attribute.py @@ -1,7 +1,11 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from tree_sitter import Node as TSNode + +from codegen.sdk.codebase.codebase_graph import CodebaseGraph from codegen.sdk.core.expressions import Expression, Name from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute +from codegen.sdk.core.node_id_factory import NodeId from codegen.shared.decorators.docs import py_apidoc if TYPE_CHECKING: @@ -16,5 +20,5 @@ class PyChainedAttribute(ChainedAttribute[Expression, Name, Parent], Generic[Par This includes methods of python classes and module functions. """ - def __init__(self, ts_node, file_node_id, G, parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent, object=ts_node.child_by_field_name("object"), attribute=ts_node.child_by_field_name("attribute")) diff --git a/src/codegen/sdk/python/expressions/conditional_expression.py b/src/codegen/sdk/python/expressions/conditional_expression.py index 01cd0c73e..11cd45f2c 100644 --- a/src/codegen/sdk/python/expressions/conditional_expression.py +++ b/src/codegen/sdk/python/expressions/conditional_expression.py @@ -1,6 +1,10 @@ from typing import TYPE_CHECKING, TypeVar +from tree_sitter import Node as TSNode + +from codegen.sdk.codebase.codebase_graph import CodebaseGraph from codegen.sdk.core.expressions.ternary_expression import TernaryExpression +from codegen.sdk.core.node_id_factory import NodeId from codegen.shared.decorators.docs import py_apidoc if TYPE_CHECKING: @@ -11,9 +15,9 @@ @py_apidoc class PyConditionalExpression(TernaryExpression[Parent]): - """Conditional Expressions (A if condition else B)""" + """Conditional Expressions (A if condition else B).""" - def __init__(self, ts_node, file_node_id, G, parent: Parent) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) self.consequence = self.children[0] self.condition = self.children[1] diff --git a/src/codegen/sdk/python/expressions/generic_type.py b/src/codegen/sdk/python/expressions/generic_type.py index 1165c1c10..1203b0c47 100644 --- a/src/codegen/sdk/python/expressions/generic_type.py +++ b/src/codegen/sdk/python/expressions/generic_type.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, Self, TypeVar from codegen.sdk.core.expressions.generic_type import GenericType from codegen.sdk.core.symbol_groups.collection import Collection @@ -8,6 +8,8 @@ from codegen.shared.decorators.docs import py_apidoc if TYPE_CHECKING: + from tree_sitter import Node as TSNode + from codegen.sdk.python.expressions.type import PyType import logging @@ -32,11 +34,11 @@ def _get_name_node(self) -> TSNode | None: return self.child_by_field_types(["identifier", "attribute"]).ts_node return self.ts_node - def _get_parameters(self) -> Collection["PyType", Self] | None: + def _get_parameters(self) -> Collection[PyType, Self] | None: if self.ts_node_type == "subscript": types = [self._parse_type(child) for child in self.ts_node.children_by_field_name("subscript")] return Collection(node=self.ts_node, file_node_id=self.file_node_id, G=self.G, parent=self, children=types) - elif self.ts_node_type == "generic_type": + if self.ts_node_type == "generic_type": type_parameter = self.ts_node.named_children[1] assert type_parameter.type == "type_parameter" types = [self._parse_type(child) for child in type_parameter.named_children] diff --git a/src/codegen/sdk/python/expressions/named_type.py b/src/codegen/sdk/python/expressions/named_type.py index b2d1bd604..28e8ae74b 100644 --- a/src/codegen/sdk/python/expressions/named_type.py +++ b/src/codegen/sdk/python/expressions/named_type.py @@ -1,16 +1,19 @@ -from typing import Generic, TypeVar +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, TypeVar from codegen.sdk.core.expressions.named_type import NamedType from codegen.shared.decorators.docs import py_apidoc +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + Parent = TypeVar("Parent") @py_apidoc class PyNamedType(NamedType[Parent], Generic[Parent]): - """Named python type + """Named python type. Examples: int,str (builtin types) diff --git a/src/codegen/sdk/python/expressions/union_type.py b/src/codegen/sdk/python/expressions/union_type.py index d6181989e..8d19f4de5 100644 --- a/src/codegen/sdk/python/expressions/union_type.py +++ b/src/codegen/sdk/python/expressions/union_type.py @@ -11,10 +11,8 @@ @py_apidoc class PyUnionType(UnionType["PyType", Parent], Generic[Parent]): - """Union type + """Union type. Examples: str | int """ - - pass diff --git a/src/codegen/sdk/python/interfaces/has_block.py b/src/codegen/sdk/python/interfaces/has_block.py index 2871196eb..f1afb2d8f 100644 --- a/src/codegen/sdk/python/interfaces/has_block.py +++ b/src/codegen/sdk/python/interfaces/has_block.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import cached_property from codegen.sdk.core.autocommit import reader, writer diff --git a/src/codegen/sdk/python/statements/comment.py b/src/codegen/sdk/python/statements/comment.py index 390107212..33e2cb840 100644 --- a/src/codegen/sdk/python/statements/comment.py +++ b/src/codegen/sdk/python/statements/comment.py @@ -77,9 +77,7 @@ def _parse_comment(self) -> str: # Remove the triple quotes and extract the text content text_block = self.source[3:-3] # Parse the text block into lines - text_lines = [] - for line in text_block.lstrip("\n").split("\n"): - text_lines.append(line) + text_lines = list(text_block.lstrip("\n").split("\n")) # Get indentation level padding = lowest_indentation(text_lines, skip_lines=skip_lines) # Remove indentation diff --git a/src/codegen/sdk/python/symbol_groups/comment_group.py b/src/codegen/sdk/python/symbol_groups/comment_group.py index 701445fab..27f53dd53 100644 --- a/src/codegen/sdk/python/symbol_groups/comment_group.py +++ b/src/codegen/sdk/python/symbol_groups/comment_group.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from docstring_parser import Docstring, DocstringStyle, parse @@ -21,7 +21,7 @@ @py_apidoc class PyCommentGroup(CommentGroup): - """A group of related symbols that represent a comment or docstring in Python + """A group of related symbols that represent a comment or docstring in Python. For example: ``` @@ -36,7 +36,7 @@ class PyCommentGroup(CommentGroup): @classmethod @noapidoc - def from_symbol_comments(cls, symbol: PySymbol): + def from_symbol_comments(cls, symbol: PySymbol) -> Self | None: siblings = symbol.parent.parent.statements comments = [] # Iterate backwards from the function node to collect all preceding comment nodes @@ -53,12 +53,11 @@ def from_symbol_comments(cls, symbol: PySymbol): from codegen.sdk.python.class_definition import PyClass # Check if the function node is a method - if symbol.symbol_type == SymbolType.Function: - if isinstance(symbol.parent_class, PyClass): - # Filter out the class docstring if it exists - if symbol.parent_class.docstring: - docstring_comments = set(symbol.parent_class.docstring.symbols) - comments = [c for c in comments if c not in docstring_comments] + if symbol.symbol_type == SymbolType.Function and isinstance(symbol.parent_class, PyClass): + # Filter out the class docstring if it exists + if symbol.parent_class.docstring: + docstring_comments = set(symbol.parent_class.docstring.symbols) + comments = [c for c in comments if c not in docstring_comments] if not comments: return None @@ -67,17 +66,16 @@ def from_symbol_comments(cls, symbol: PySymbol): @classmethod @noapidoc - def from_symbol_inline_comments(cls, symbol: PySymbol, node: TSNode | None = None): + def from_symbol_inline_comments(cls, symbol: PySymbol, node: TSNode | None = None) -> Self | None: statement = symbol.parent index = statement.index siblings = statement.parent.statements comment_nodes = [] # Check if there are any comments after the function node - if index + 1 < len(siblings): - if siblings[index + 1].statement_type == StatementType.COMMENT: - # Check if the comment is on the same line - if siblings[index].end_point[0] == siblings[index + 1].start_point[0]: - comment_nodes.append(siblings[index + 1]) + if index + 1 < len(siblings) and siblings[index + 1].statement_type == StatementType.COMMENT: + # Check if the comment is on the same line + if siblings[index].end_point[0] == siblings[index + 1].start_point[0]: + comment_nodes.append(siblings[index + 1]) if not comment_nodes: return None @@ -86,13 +84,13 @@ def from_symbol_inline_comments(cls, symbol: PySymbol, node: TSNode | None = Non @classmethod @noapidoc - def from_docstring(cls, symbol: PySymbol): + def from_docstring(cls, symbol: PySymbol) -> Self | None: # Check if there is an expression node above the symbol top_child = symbol.code_block.ts_node.children[0] if top_child.type == "expression_statement": string_node = top_child.children[0] if string_node.type == "string": - text = string_node.text.decode("utf-8") + string_node.text.decode("utf-8") comment_node = PyComment.from_code_block(string_node, symbol) return cls([comment_node], symbol.file_node_id, symbol.G, symbol) return None @@ -114,7 +112,7 @@ def to_google_docstring(self, function: PyFunction) -> str: # pragma: no cover NAME_OF_RETURNS_SECTION = "Returns:" def parse_google_block(section_header: str, first_line: str, docstring_iter) -> str: - """Parse the parameters section of the docstring""" + """Parse the parameters section of the docstring.""" unrelated_strings = [] parameters = {} @@ -137,7 +135,7 @@ def parse_google_block(section_header: str, first_line: str, docstring_iter) -> return unrelated_strings, parameters def merge_codebase_docstring(codebase_doc, parsed_doc): - """Merge the codebase docstring with the parsed docstring""" + """Merge the codebase docstring with the parsed docstring.""" for param_name, (param_type, param_description) in codebase_doc.items(): if param_name in parsed_doc: # Merge the types and descriptions diff --git a/src/codegen/sdk/ruff.toml b/src/codegen/sdk/ruff.toml new file mode 100644 index 000000000..a8f4c33d6 --- /dev/null +++ b/src/codegen/sdk/ruff.toml @@ -0,0 +1,10 @@ +extend = "../../../ruff.toml" + +[lint] +extend-select = ["PERF"] +extend-ignore = ["PERF401"] + + +[lint.per-file-ignores] +"skills/**.py" = ["PERF"] +"utils/**.py" = ["PERF"] diff --git a/src/codegen/sdk/tree_sitter_parser.py b/src/codegen/sdk/tree_sitter_parser.py index c9a04c6cc..361d174ea 100644 --- a/src/codegen/sdk/tree_sitter_parser.py +++ b/src/codegen/sdk/tree_sitter_parser.py @@ -82,7 +82,7 @@ def print_errors(filepath: PathLike, content: str) -> None: ts_node = parser.parse(bytes(content, "utf-8")).root_node if ts_node.has_error: - def traverse(node): + def traverse(node) -> None: if node.is_error or node.is_missing: stylize_error(filepath, node.start_point, node.end_point, ts_node, content, "with ts_node type of " + node.type) if node.has_error: diff --git a/src/codegen/sdk/typescript/config_parser.py b/src/codegen/sdk/typescript/config_parser.py index 2f5e79861..93d1340e4 100644 --- a/src/codegen/sdk/typescript/config_parser.py +++ b/src/codegen/sdk/typescript/config_parser.py @@ -19,7 +19,7 @@ class TSConfigParser(ConfigParser): config_files: dict[Path, TSConfig] G: "CodebaseGraph" - def __init__(self, codebase_graph: "CodebaseGraph", default_config_name: str = "tsconfig.json"): + def __init__(self, codebase_graph: "CodebaseGraph", default_config_name: str = "tsconfig.json") -> None: super().__init__() self.config_files = dict() self.G = codebase_graph @@ -34,7 +34,7 @@ def get_config(self, config_path: os.PathLike) -> TSConfig | None: return self.config_files.get(path) return None - def parse_configs(self): + def parse_configs(self) -> None: # This only yields a 0.05s speedup, but its funny writing dynamic programming code @cache def get_config_for_dir(dir_path: Path) -> TSConfig | None: diff --git a/src/codegen/sdk/typescript/detached_symbols/code_block.py b/src/codegen/sdk/typescript/detached_symbols/code_block.py index c69d73f73..dee9f401f 100644 --- a/src/codegen/sdk/typescript/detached_symbols/code_block.py +++ b/src/codegen/sdk/typescript/detached_symbols/code_block.py @@ -28,7 +28,7 @@ def _parse_statements(self) -> MultiLineCollection[Statement, Self]: start_node = line_nodes[1][0] if len(line_nodes) > 1 else line_nodes[0][0] end_node = line_nodes[-2][1] if len(line_nodes) > 1 else line_nodes[-1][1] indent_size = start_node.start_point[1] - collection = MultiLineCollection( + return MultiLineCollection( children=statements, file_node_id=self.file_node_id, G=self.G, @@ -39,12 +39,11 @@ def _parse_statements(self) -> MultiLineCollection[Statement, Self]: start_byte=start_node.start_byte - indent_size, end_byte=end_node.end_byte + 1, ) - return collection @reader @noapidoc def _get_line_starts(self) -> list[Editable]: - """Returns an ordered list of first Editable for each non-empty line within the code block""" + """Returns an ordered list of first Editable for each non-empty line within the code block.""" line_start_nodes = super()._get_line_starts() if len(line_start_nodes) >= 3 and line_start_nodes[0].source == "{" and line_start_nodes[-1].source == "}": # Remove the first and last line of the code block as they are opening and closing braces. @@ -54,7 +53,7 @@ def _get_line_starts(self) -> list[Editable]: @reader @noapidoc def _get_line_ends(self) -> list[Editable]: - """Returns an ordered list of last Editable for each non-empty line within the code block""" + """Returns an ordered list of last Editable for each non-empty line within the code block.""" line_end_nodes = super()._get_line_ends() # Remove the first and last line of the code block as they are opening and closing braces. return line_end_nodes[1:-1] diff --git a/src/codegen/sdk/typescript/detached_symbols/decorator.py b/src/codegen/sdk/typescript/detached_symbols/decorator.py index e06da867e..61c301794 100644 --- a/src/codegen/sdk/typescript/detached_symbols/decorator.py +++ b/src/codegen/sdk/typescript/detached_symbols/decorator.py @@ -17,7 +17,7 @@ @ts_apidoc class TSDecorator(Decorator["TSClass", "TSFunction", "TSParameter"]): - """Abstract representation of a Decorator""" + """Abstract representation of a Decorator.""" @reader def _get_name_node(self) -> TSNode: @@ -25,19 +25,13 @@ def _get_name_node(self) -> TSNode: for child in self.ts_node.children: # =====[ Identifier ]===== # Just `@dataclass` etc. - if child.type == "identifier": - return child - - # =====[ Attribute ]===== - # e.g. `@a.b` - elif child.type == "member_expression": + if child.type in ("identifier", "member_expression"): return child # =====[ Call ]===== # e.g. `@a.b()` - elif child.type == "call_expression": - func = child.child_by_field_name("function") - return func + if child.type == "call_expression": + return child.child_by_field_name("function") msg = f"Could not find decorator name within {self.source}" raise ValueError(msg) diff --git a/src/codegen/sdk/typescript/detached_symbols/jsx/element.py b/src/codegen/sdk/typescript/detached_symbols/jsx/element.py index 407e9f47b..e6fe42ee4 100644 --- a/src/codegen/sdk/typescript/detached_symbols/jsx/element.py +++ b/src/codegen/sdk/typescript/detached_symbols/jsx/element.py @@ -77,7 +77,7 @@ def expressions(self) -> list[JSXExpression]: @noapidoc @reader def _attribute_nodes(self) -> list[Editable]: - """Returns all attribute nodes of the element""" + """Returns all attribute nodes of the element.""" open_tag = self.ts_node.child_by_field_name("open_tag") or self.ts_node attribute_nodes = open_tag.children_by_field_name("attribute") return [Value(x, self.file_node_id, self.G, self) for x in attribute_nodes] @@ -175,7 +175,7 @@ def add_prop(self, prop_name: str, prop_value: str) -> None: @reader @noapidoc def _source(self): - """Text representation of the Editable instance""" + """Text representation of the Editable instance.""" return self.ts_node.text.decode("utf-8").strip() @writer diff --git a/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py b/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py index 833e1ec77..c075fbbc6 100644 --- a/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py +++ b/src/codegen/sdk/typescript/detached_symbols/jsx/expression.py @@ -1,26 +1,28 @@ +from __future__ import annotations + from functools import cached_property from typing import TYPE_CHECKING, Self, override from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression -from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.unwrappable import Unwrappable from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from codegen.sdk.core.dataclasses.usage import UsageKind + from codegen.sdk.core.expressions import Expression from codegen.sdk.core.function import Function + from codegen.sdk.core.interfaces.editable import Editable + from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement from codegen.sdk.typescript.detached_symbols.jsx.prop import JSXProp @ts_apidoc class JSXExpression(Unwrappable["Function | JSXElement | JSXProp"]): - """Abstract representation of TSX/JSX expression""" + """Abstract representation of TSX/JSX expression.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.statement @@ -48,7 +50,6 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: HasName | None = No def reduce_condition(self, bool_condition: bool, node: Editable) -> None: """Simplifies a JSX expression by reducing it based on a boolean condition. - Args: bool_condition (bool): The boolean value to reduce the condition to. @@ -63,7 +64,6 @@ def reduce_condition(self, bool_condition: bool, node: Editable) -> None: def unwrap(self, node: Expression | None = None) -> None: """Removes the brackets from a JSX expression. - Returns: None """ diff --git a/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py b/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py index b7521be55..cc7df2530 100644 --- a/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py +++ b/src/codegen/sdk/typescript/detached_symbols/jsx/prop.py @@ -1,20 +1,22 @@ -from typing import TYPE_CHECKING, override +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, override from codegen.sdk.core.autocommit import reader, writer -from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions import Expression from codegen.sdk.core.expressions.name import Name from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.has_value import HasValue from codegen.sdk.extensions.autocommit import commiter -from codegen.sdk.typescript.detached_symbols.jsx.expression import JSXExpression from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.function import Function from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement + from codegen.sdk.typescript.detached_symbols.jsx.expression import JSXExpression @ts_apidoc @@ -24,7 +26,7 @@ class JSXProp(Expression["Function | JSXElement | JSXProp"], HasName, HasValue): _name_node: Name | None _expression_node: JSXExpression | None - def __init__(self, ts_node: TSNode, parent: "Function | JSXElement | JSXProp") -> None: + def __init__(self, ts_node: TSNode, parent: Function | JSXElement | JSXProp) -> None: super().__init__(ts_node, parent.file_node_id, parent.G, parent) self._name_node = self._parse_expression(self.ts_node.children[0], default=Name) if len(self.ts_node.children) > 2: diff --git a/src/codegen/sdk/typescript/detached_symbols/parameter.py b/src/codegen/sdk/typescript/detached_symbols/parameter.py index 3a3a67dae..4f33e0584 100644 --- a/src/codegen/sdk/typescript/detached_symbols/parameter.py +++ b/src/codegen/sdk/typescript/detached_symbols/parameter.py @@ -68,8 +68,7 @@ def is_optional(self) -> bool: # In this case, individual destructured parameters are not marked as optional # The entire object might be optional, but that's handled at the function level return False - else: - return self.ts_node.type == "optional_parameter" + return self.ts_node.type == "optional_parameter" @property @reader @@ -101,8 +100,7 @@ def default(self) -> str | None: if self.is_destructured: if self.ts_node.type == "object_assignment_pattern": return self.ts_node.children[-1].text.decode("utf-8") - else: - return None + return None # =====[ Not destructured ]===== default_node = self.ts_node.child_by_field_name("value") @@ -114,9 +112,8 @@ def default(self) -> str | None: @commiter @override def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasName | None = None) -> None: - if self.type: - if not (self.is_destructured and self.index > 0): - self.type._compute_dependencies(UsageKind.TYPE_ANNOTATION, dest or self.parent.self_dest) + if self.type and not (self.is_destructured and self.index > 0): + self.type._compute_dependencies(UsageKind.TYPE_ANNOTATION, dest or self.parent.self_dest) if self.value: self.value._compute_dependencies(UsageKind.DEFAULT_VALUE, dest or self.parent.self_dest) diff --git a/src/codegen/sdk/typescript/export.py b/src/codegen/sdk/typescript/export.py index 24967852b..04fcff3b5 100644 --- a/src/codegen/sdk/typescript/export.py +++ b/src/codegen/sdk/typescript/export.py @@ -674,7 +674,7 @@ class WildcardExport(Chainable, Generic[TExport]): exp: TExport symbol: Exportable - def __init__(self, exp: TExport, symbol: Exportable): + def __init__(self, exp: TExport, symbol: Exportable) -> None: self.exp = exp self.symbol = symbol diff --git a/src/codegen/sdk/typescript/expressions/array_type.py b/src/codegen/sdk/typescript/expressions/array_type.py index 0fe714bbd..d9608b655 100644 --- a/src/codegen/sdk/typescript/expressions/array_type.py +++ b/src/codegen/sdk/typescript/expressions/array_type.py @@ -1,10 +1,13 @@ -from typing import TypeVar +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, TypeVar from codegen.sdk.typescript.expressions.named_type import TSNamedType from codegen.shared.decorators.docs import ts_apidoc +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + Parent = TypeVar("Parent") @@ -12,7 +15,7 @@ class TSArrayType(TSNamedType[Parent]): """Array type Examples: - string[] + string[]. """ def _get_name_node(self) -> TSNode | None: diff --git a/src/codegen/sdk/typescript/expressions/chained_attribute.py b/src/codegen/sdk/typescript/expressions/chained_attribute.py index 5277ceb20..bbab0406f 100644 --- a/src/codegen/sdk/typescript/expressions/chained_attribute.py +++ b/src/codegen/sdk/typescript/expressions/chained_attribute.py @@ -1,8 +1,12 @@ from typing import TYPE_CHECKING, Generic, TypeVar +from tree_sitter import Node as TSNode + +from codegen.sdk.codebase.codebase_graph import CodebaseGraph from codegen.sdk.core.detached_symbols.function_call import FunctionCall from codegen.sdk.core.expressions import Expression, Name from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute +from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.extensions.autocommit import reader from codegen.shared.decorators.docs import ts_apidoc @@ -21,7 +25,7 @@ class TSChainedAttribute(ChainedAttribute[Expression, Name, Parent], Generic[Par and property components of the expression, as well as analyzing function calls made on the object. """ - def __init__(self, ts_node, file_node_id, G, parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent, object=ts_node.child_by_field_name("object"), attribute=ts_node.child_by_field_name("property")) @property diff --git a/src/codegen/sdk/typescript/expressions/conditional_type.py b/src/codegen/sdk/typescript/expressions/conditional_type.py index b17789526..4071ce4ee 100644 --- a/src/codegen/sdk/typescript/expressions/conditional_type.py +++ b/src/codegen/sdk/typescript/expressions/conditional_type.py @@ -21,7 +21,7 @@ @ts_apidoc class TSConditionalType(Type[Parent], Generic[Parent]): - """Conditional Type + """Conditional Type. Examples: typeof s @@ -32,14 +32,14 @@ class TSConditionalType(Type[Parent], Generic[Parent]): consequence: "TSType[Self]" alternative: "TSType[Self]" - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.left = self.child_by_field_name("left") self.right = self.child_by_field_name("right") self.consequence = self.child_by_field_name("consequence") self.alternative = self.child_by_field_name("alternative") - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: self.left._compute_dependencies(usage_type, dest) self.right._compute_dependencies(usage_type, dest) self.consequence._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/typescript/expressions/expression_type.py b/src/codegen/sdk/typescript/expressions/expression_type.py index 6f40dd3e1..d89799bb5 100644 --- a/src/codegen/sdk/typescript/expressions/expression_type.py +++ b/src/codegen/sdk/typescript/expressions/expression_type.py @@ -16,7 +16,7 @@ @ts_apidoc class TSExpressionType(TSNamedType, Generic[Parent]): - """Type defined by evaluation of an expression + """Type defined by evaluation of an expression. Attributes: expression: The expression to evaluate that yields the type @@ -24,6 +24,6 @@ class TSExpressionType(TSNamedType, Generic[Parent]): expression: Expression["TSExpressionType[Parent]"] - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.expression = self._parse_expression(ts_node) diff --git a/src/codegen/sdk/typescript/expressions/function_type.py b/src/codegen/sdk/typescript/expressions/function_type.py index 6b00e0510..c008b0c1b 100644 --- a/src/codegen/sdk/typescript/expressions/function_type.py +++ b/src/codegen/sdk/typescript/expressions/function_type.py @@ -1,21 +1,24 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override -from codegen.sdk.codebase.resolution_stack import ResolutionStack from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.core.symbol_groups.collection import Collection from codegen.sdk.typescript.detached_symbols.parameter import TSParameter from codegen.sdk.typescript.placeholder.placeholder_return_type import TSReturnTypePlaceholder from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.codebase.resolution_stack import ResolutionStack + from codegen.sdk.core.interfaces.importable import Importable + from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.expressions.type import TSType @@ -33,11 +36,11 @@ class TSFunctionType(Type[Parent], Generic[Parent]): a: (a: number) => number """ - return_type: "TSType[Self] | TSReturnTypePlaceholder[Self]" + return_type: TSType[Self] | TSReturnTypePlaceholder[Self] _parameters: Collection[TSParameter, Self] name: None = None # This lets parameters generate their node_id properly - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.return_type = self.child_by_field_name("return_type", placeholder=TSReturnTypePlaceholder) params_node = self.ts_node.child_by_field_name("parameters") @@ -74,7 +77,7 @@ def asyncify(self) -> None: self.return_type.insert_before("Promise<", newline=False) self.return_type.insert_after(">", newline=False) - def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: Importable | None = None): + def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: Importable | None = None) -> None: if self.return_type: self.return_type._compute_dependencies(UsageKind.GENERIC, dest) diff --git a/src/codegen/sdk/typescript/expressions/generic_type.py b/src/codegen/sdk/typescript/expressions/generic_type.py index 71c0b9971..0a89521bb 100644 --- a/src/codegen/sdk/typescript/expressions/generic_type.py +++ b/src/codegen/sdk/typescript/expressions/generic_type.py @@ -1,13 +1,15 @@ -from typing import TYPE_CHECKING, Self, TypeVar +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Self, TypeVar from codegen.sdk.core.expressions.generic_type import GenericType from codegen.sdk.core.symbol_groups.collection import Collection -from codegen.sdk.core.symbol_groups.dict import Dict from codegen.shared.decorators.docs import ts_apidoc if TYPE_CHECKING: + from tree_sitter import Node as TSNode + + from codegen.sdk.core.symbol_groups.dict import Dict from codegen.sdk.typescript.expressions.type import TSType Parent = TypeVar("Parent") @@ -15,7 +17,7 @@ @ts_apidoc class TSGenericType(GenericType["TSType", Parent]): - """Generic type + """Generic type. Examples: `Array` diff --git a/src/codegen/sdk/typescript/expressions/lookup_type.py b/src/codegen/sdk/typescript/expressions/lookup_type.py index 7ee97f6d6..a76dbf976 100644 --- a/src/codegen/sdk/typescript/expressions/lookup_type.py +++ b/src/codegen/sdk/typescript/expressions/lookup_type.py @@ -1,19 +1,22 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override -from codegen.sdk.codebase.resolution_stack import ResolutionStack from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions import Expression from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.codebase.resolution_stack import ResolutionStack + from codegen.sdk.core.dataclasses.usage import UsageKind + from codegen.sdk.core.expressions import Expression + from codegen.sdk.core.interfaces.importable import Importable + from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.expressions.type import TSType @@ -24,13 +27,13 @@ class TSLookupType(Type[Parent], Generic[Parent]): """Type lookup Examples: - a["key"] + a["key"]. """ - type: "TSType[Self]" + type: TSType[Self] lookup: Expression - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.type = self._parse_type(ts_node.named_children[0]) if literal_type := self.child_by_field_types("literal_type"): @@ -56,5 +59,5 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: self._log_parse("Cannot resolve lookup type properly") yield from self.with_resolution_frame(self.type) - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: self.type._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/typescript/expressions/named_type.py b/src/codegen/sdk/typescript/expressions/named_type.py index 223f61de5..b7e4c4c85 100644 --- a/src/codegen/sdk/typescript/expressions/named_type.py +++ b/src/codegen/sdk/typescript/expressions/named_type.py @@ -1,10 +1,13 @@ -from typing import TypeVar +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, TypeVar from codegen.sdk.core.expressions.named_type import NamedType from codegen.shared.decorators.docs import ts_apidoc +if TYPE_CHECKING: + from tree_sitter import Node as TSNode + Parent = TypeVar("Parent") @@ -12,7 +15,7 @@ class TSNamedType(NamedType[Parent]): """Named type Examples: - string + string. """ def _get_name_node(self) -> TSNode | None: diff --git a/src/codegen/sdk/typescript/expressions/object_type.py b/src/codegen/sdk/typescript/expressions/object_type.py index 63f5c3cd6..3d048265d 100644 --- a/src/codegen/sdk/typescript/expressions/object_type.py +++ b/src/codegen/sdk/typescript/expressions/object_type.py @@ -1,18 +1,20 @@ -from typing import TYPE_CHECKING, Generic, Self, TypeVar +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, Self, TypeVar -from codegen.sdk.core.dataclasses.usage import UsageKind -from codegen.sdk.core.expressions.expression import Expression from codegen.sdk.core.expressions.type import Type from codegen.sdk.core.expressions.value import Value -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.symbol_groups.dict import TSDict, TSPair from codegen.shared.decorators.docs import ts_apidoc if TYPE_CHECKING: + from tree_sitter import Node as TSNode + from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.core.dataclasses.usage import UsageKind + from codegen.sdk.core.expressions.expression import Expression + from codegen.sdk.core.interfaces.importable import Importable + from codegen.sdk.core.node_id_factory import NodeId import logging @@ -23,7 +25,7 @@ class TSObjectPair(TSPair, Generic[Parent]): - """Object type + """Object type. Examples: a: {a: int; b?(a: int): c} @@ -72,10 +74,10 @@ class TSObjectType(TSDict, Type[Parent], Generic[Parent]): in TypeScript code. """ - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent, delimiter=";", pair_type=TSObjectPair) - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: for child in self.values(): if isinstance(child, Type): child._compute_dependencies(usage_type, dest) diff --git a/src/codegen/sdk/typescript/expressions/query_type.py b/src/codegen/sdk/typescript/expressions/query_type.py index 4c34f849e..a08c71b15 100644 --- a/src/codegen/sdk/typescript/expressions/query_type.py +++ b/src/codegen/sdk/typescript/expressions/query_type.py @@ -1,18 +1,21 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override -from codegen.sdk.codebase.resolution_stack import ResolutionStack from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.codebase.resolution_stack import ResolutionStack + from codegen.sdk.core.dataclasses.usage import UsageKind + from codegen.sdk.core.interfaces.importable import Importable + from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.expressions.type import TSType @@ -21,15 +24,15 @@ @ts_apidoc class TSQueryType(Type[Parent], Generic[Parent]): - """Type query + """Type query. Examples: typeof s """ - query: "TSType[Self]" + query: TSType[Self] - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.query = self._parse_type(ts_node.named_children[0]) @@ -46,7 +49,7 @@ def name(self) -> str | None: """ return self.query.name - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: self.query._compute_dependencies(usage_type, dest) @reader diff --git a/src/codegen/sdk/typescript/expressions/readonly_type.py b/src/codegen/sdk/typescript/expressions/readonly_type.py index 10079de66..46b58f8f7 100644 --- a/src/codegen/sdk/typescript/expressions/readonly_type.py +++ b/src/codegen/sdk/typescript/expressions/readonly_type.py @@ -1,18 +1,21 @@ -from collections.abc import Generator -from typing import TYPE_CHECKING, Generic, Self, TypeVar, override +from __future__ import annotations -from tree_sitter import Node as TSNode +from typing import TYPE_CHECKING, Generic, Self, TypeVar, override -from codegen.sdk.codebase.resolution_stack import ResolutionStack from codegen.sdk.core.autocommit import reader -from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions.type import Type -from codegen.sdk.core.interfaces.importable import Importable -from codegen.sdk.core.node_id_factory import NodeId from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Generator + + from tree_sitter import Node as TSNode + from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.codebase.resolution_stack import ResolutionStack + from codegen.sdk.core.dataclasses.usage import UsageKind + from codegen.sdk.core.interfaces.importable import Importable + from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.expressions.type import TSType @@ -21,15 +24,15 @@ @ts_apidoc class TSReadonlyType(Type[Parent], Generic[Parent]): - """Readonly type + """Readonly type. Examples: readonly s """ - type: "TSType[Self]" + type: TSType[Self] - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent): + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.type = self._parse_type(ts_node.named_children[0]) @@ -46,7 +49,7 @@ def name(self) -> str | None: """ return self.type.name - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: self.type._compute_dependencies(usage_type, dest) @reader diff --git a/src/codegen/sdk/typescript/expressions/ternary_expression.py b/src/codegen/sdk/typescript/expressions/ternary_expression.py index aa141978d..4b73f56e0 100644 --- a/src/codegen/sdk/typescript/expressions/ternary_expression.py +++ b/src/codegen/sdk/typescript/expressions/ternary_expression.py @@ -1,6 +1,10 @@ from typing import TYPE_CHECKING, TypeVar +from tree_sitter import Node as TSNode + +from codegen.sdk.codebase.codebase_graph import CodebaseGraph from codegen.sdk.core.expressions.ternary_expression import TernaryExpression +from codegen.sdk.core.node_id_factory import NodeId from codegen.shared.decorators.docs import ts_apidoc if TYPE_CHECKING: @@ -11,9 +15,9 @@ @ts_apidoc class TSTernaryExpression(TernaryExpression[Parent]): - """Any ternary expression in the code where a condition will determine branched execution""" + """Any ternary expression in the code where a condition will determine branched execution.""" - def __init__(self, ts_node, file_node_id, G, parent: Parent) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent=parent) self.condition = self.child_by_field_name("condition") self.consequence = self.child_by_field_name("consequence") diff --git a/src/codegen/sdk/typescript/expressions/undefined_type.py b/src/codegen/sdk/typescript/expressions/undefined_type.py index 0a0abd49e..1419de4f1 100644 --- a/src/codegen/sdk/typescript/expressions/undefined_type.py +++ b/src/codegen/sdk/typescript/expressions/undefined_type.py @@ -15,11 +15,11 @@ class TSUndefinedType(Type[Parent], Generic[Parent]): """Undefined type. Represents the undefined keyword Examples: - undefined + undefined. """ @noapidoc - def _compute_dependencies(self, usage_type: UsageKind, dest: Importable): + def _compute_dependencies(self, usage_type: UsageKind, dest: Importable) -> None: pass @reader diff --git a/src/codegen/sdk/typescript/expressions/union_type.py b/src/codegen/sdk/typescript/expressions/union_type.py index b0df5b24d..4e89f913b 100644 --- a/src/codegen/sdk/typescript/expressions/union_type.py +++ b/src/codegen/sdk/typescript/expressions/union_type.py @@ -11,10 +11,8 @@ @ts_apidoc class TSUnionType(UnionType["TSType", Parent], Generic[Parent]): - """Union type + """Union type. Examples: string | number """ - - pass diff --git a/src/codegen/sdk/typescript/external/dependency_manager.py b/src/codegen/sdk/typescript/external/dependency_manager.py index d968a2bac..02ceba264 100644 --- a/src/codegen/sdk/typescript/external/dependency_manager.py +++ b/src/codegen/sdk/typescript/external/dependency_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import concurrent.futures import json import logging @@ -39,7 +41,7 @@ class TypescriptDependencyManager(DependencyManager): """Handles dependency management for Typescript projects. Uses npm, yarn, or pnpm if applicable.""" - def __init__(self, repo_path: str, base_path: str | None = None, should_install_dependencies: bool = True, force_installer: str | None = None): + def __init__(self, repo_path: str, base_path: str | None = None, should_install_dependencies: bool = True, force_installer: str | None = None) -> None: super().__init__(repo_path, base_path) logger.info(f"Initializing TypescriptDependencyManager with should_install_dependencies={should_install_dependencies}") # Ensure that node, npm, yarn, and pnpm are installed @@ -75,14 +77,13 @@ def __init__(self, repo_path: str, base_path: str | None = None, should_install_ def _detect_installer_type(self) -> InstallerType: if os.path.exists(os.path.join(self.full_path, "yarn.lock")): return InstallerType.YARN - elif os.path.exists(os.path.join(self.full_path, "package-lock.json")): + if os.path.exists(os.path.join(self.full_path, "package-lock.json")): return InstallerType.NPM - elif os.path.exists(os.path.join(self.full_path, "pnpm-lock.yaml")): + if os.path.exists(os.path.join(self.full_path, "pnpm-lock.yaml")): return InstallerType.PNPM - else: - logger.warning("Could not detect installer type. Defaulting to NPM!") - return InstallerType.NPM - # return InstallerType.UNKNOWN + logger.warning("Could not detect installer type. Defaulting to NPM!") + return InstallerType.NPM + # return InstallerType.UNKNOWN @staticmethod def _check_package_exists(package_name: str) -> bool: @@ -120,12 +121,12 @@ def _validate_dependencies(cls, deps: dict[str, str]) -> tuple[dict[str, str], d return valid_deps, invalid_deps - def parse_dependencies(self): + def parse_dependencies(self) -> None: # Clear the package_json_data self.package_json_data.clear() # Walk through directory tree - for current_dir, subdirs, files in os.walk(self.full_path): + for current_dir, _subdirs, files in os.walk(self.full_path): # Skip node_modules directories if "node_modules" in current_dir: continue @@ -151,14 +152,14 @@ def parse_dependencies(self): logger.exception(f"Could not find package.json at {package_json_path}") except ValueError: logger.exception(f"Invalid json in package.json at {package_json_path}") - except Exception as e: - raise e + except Exception: + raise # Set the base package.json data base_package_json_path = os.path.join(self.full_path, "package.json") self.base_package_json_data = self.package_json_data.get(base_package_json_path, None) - def _install_dependencies_npm(self): + def _install_dependencies_npm(self) -> None: logger.info("Installing dependencies with NPM") # Shadow package-lock.json, if it exists files_to_shadow = [] @@ -188,7 +189,7 @@ def _install_dependencies_npm(self): logger.exception(f"NPM FAIL stderr: {e.stderr}") raise - def _install_dependencies_yarn(self): + def _install_dependencies_yarn(self) -> None: logger.info("Installing dependencies with Yarn") # Shadow yarn.lock, yarn.config.cjs, and .yarnrc.yml, if they exist files_to_shadow = [] @@ -268,7 +269,7 @@ def _install_dependencies_yarn(self): # Delete the .yarnrc.yml file os.remove(os.path.join(self.full_path, ".yarnrc.yml")) - def _install_dependencies_pnpm(self): + def _install_dependencies_pnpm(self) -> None: logger.info("Installing dependencies with PNPM") # Shadow pnpm-lock.yaml, if it exists files_to_shadow = [] @@ -297,7 +298,7 @@ def _install_dependencies_pnpm(self): logger.exception(f"PNPM FAIL stderr: {e.stderr}") raise - def _clean_package_json(self, package_json_path: str): + def _clean_package_json(self, package_json_path: str) -> None: # Get the package data data = self.package_json_data[package_json_path] @@ -325,7 +326,7 @@ def _clean_package_json(self, package_json_path: str): json_str = json.dumps(clean_package_data, indent=2) f.write(json_str) - def install_dependencies(self, validate_dependencies: bool = True): + def install_dependencies(self, validate_dependencies: bool = True) -> None: if validate_dependencies: with shadow_files(list(self.package_json_data.keys())): logger.info(f"Cleaning package.json files: {list(self.package_json_data.keys())}") @@ -334,23 +335,23 @@ def install_dependencies(self, validate_dependencies: bool = True): # Install dependencies, now that we have a valid package.json return self.install_dependencies(validate_dependencies=False) + elif self.installer_type == InstallerType.NPM: + return self._install_dependencies_npm() + elif self.installer_type == InstallerType.YARN: + return self._install_dependencies_yarn() + elif self.installer_type == InstallerType.PNPM: + return self._install_dependencies_pnpm() else: - if self.installer_type == InstallerType.NPM: - return self._install_dependencies_npm() - elif self.installer_type == InstallerType.YARN: - return self._install_dependencies_yarn() - elif self.installer_type == InstallerType.PNPM: - return self._install_dependencies_pnpm() - else: - logger.warning(f"Installer type {self.installer_type} not implemented") - - def remove_dependencies(self): + logger.warning(f"Installer type {self.installer_type} not implemented") + return None + + def remove_dependencies(self) -> None: # Delete node_modules folder if it exists node_modules_path = os.path.join(self.full_path, "node_modules") if os.path.exists(node_modules_path): shutil.rmtree(node_modules_path) - def _start(self): + def _start(self) -> None: try: logger.info(f"Starting TypescriptDependencyManager with should_install_dependencies={self.should_install_dependencies}") super()._start() diff --git a/src/codegen/sdk/typescript/external/ts_analyzer_engine.py b/src/codegen/sdk/typescript/external/ts_analyzer_engine.py index 921cca499..ce0ab1c19 100644 --- a/src/codegen/sdk/typescript/external/ts_analyzer_engine.py +++ b/src/codegen/sdk/typescript/external/ts_analyzer_engine.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import os @@ -8,14 +10,15 @@ from pathlib import Path from typing import TYPE_CHECKING -from py_mini_racer import MiniRacer -from py_mini_racer._objects import JSMappedObject from py_mini_racer._types import JSEvalException from codegen.sdk.core.external.language_engine import LanguageEngine from codegen.sdk.typescript.external.mega_racer import MegaRacer if TYPE_CHECKING: + from py_mini_racer import MiniRacer + from py_mini_racer._objects import JSMappedObject + from codegen.sdk.core.external.dependency_manager import DependencyManager from codegen.sdk.core.interfaces.editable import Editable @@ -24,14 +27,14 @@ class TypescriptEngine(LanguageEngine): - dependency_manager: "DependencyManager | None" + dependency_manager: DependencyManager | None - def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: "DependencyManager | None" = None): + def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: DependencyManager | None = None) -> None: super().__init__(repo_path, base_path) self.dependency_manager = dependency_manager @abstractmethod - def _start(self): + def _start(self) -> None: # If a dependency manager is provided, make sure it is ready if self.dependency_manager: logger.info(f"TypescriptEngine: Waiting for {self.dependency_manager.__class__.__name__} to be ready...") @@ -59,10 +62,10 @@ def __init__( self, repo_path: str, base_path: str | None = None, - dependency_manager: "DependencyManager | None" = None, + dependency_manager: DependencyManager | None = None, hard_memory_limit: int = 1024 * 1024 * 1024 * 16, soft_memory_limit: int = 1024 * 1024 * 1024 * 8, - ): + ) -> None: super().__init__(repo_path, base_path, dependency_manager) logger.info(f"Initializing V8TypescriptEngine with hard_memory_limit={hard_memory_limit} and soft_memory_limit={soft_memory_limit}") self.hard_memory_limit: int = hard_memory_limit @@ -79,7 +82,7 @@ def __init__( self.engine_source: str = open(self.engine_path).read() self._patch_engine_source() - def _start(self): + def _start(self) -> None: try: logger.info("Starting V8TypescriptEngine") super()._start() @@ -114,7 +117,7 @@ def _start(self): self._error = e logger.error(f"Error starting V8TypescriptEngine: {e}", exc_info=True) - def _populate_fs_files(self, fs_files: dict): + def _populate_fs_files(self, fs_files: dict) -> None: for root, _, files in os.walk(self.full_path): for filename in files: file_path = Path(root) / filename @@ -126,16 +129,15 @@ def _populate_fs_files(self, fs_files: dict): try: with open(file_path, encoding="utf-8") as f: - if "node_modules" in s_fp: - if not s_fp.endswith(".json") and not s_fp.endswith(".d.ts"): - continue + if "node_modules" in s_fp and not s_fp.endswith(".json") and not s_fp.endswith(".d.ts"): + continue content = f.read() fs_files[str(file_path)] = content except (UnicodeDecodeError, OSError): # Skip files that can't be read as text continue - def _patch_engine_source(self): + def _patch_engine_source(self) -> None: """MiniRacer does not support require and export, so we need to patch the engine source to remove them.""" logger.info("Patching engine source to remove require and export") patch_map = { @@ -149,11 +151,11 @@ def _patch_engine_source(self): for old, new in patch_map.items(): self.engine_source = self.engine_source.replace(old, new) - def get_return_type(self, node: "Editable") -> str | None: + def get_return_type(self, node: Editable) -> str | None: file_path = os.path.join(self.repo_path, node.filepath) try: return self.ctx.eval(f"type_script_analyzer.getFunctionAtPosition('{file_path}', {node.start_byte})") - except JSEvalException as e: + except JSEvalException: return None @@ -168,7 +170,7 @@ class NodeTypescriptEngine(TypescriptEngine): type_data: dict | None - def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: "DependencyManager | None" = None): + def __init__(self, repo_path: str, base_path: str | None = None, dependency_manager: DependencyManager | None = None) -> None: super().__init__(repo_path, base_path, dependency_manager) logger.info("Initializing NodeTypescriptEngine") self.type_data: dict | None = None @@ -187,7 +189,7 @@ def __init__(self, repo_path: str, base_path: str | None = None, dependency_mana msg = f"Typescript analyzer not found at {self.analyzer_path}" raise FileNotFoundError(msg) - def _start(self): + def _start(self) -> None: try: logger.info("Starting NodeTypescriptEngine") super()._start() @@ -239,7 +241,7 @@ def _start(self): self._error = e logger.error(f"Error starting NodeTypescriptEngine: {e}", exc_info=True) - def get_return_type(self, node: "Editable") -> str | None: + def get_return_type(self, node: Editable) -> str | None: file_path: str = os.path.join(self.repo_path, node.filepath) if not self.type_data: return None @@ -247,4 +249,4 @@ def get_return_type(self, node: "Editable") -> str | None: file_data: dict = codebase_data.get(file_path, {}) functions_data: dict = file_data.get("functions", {}) function_data: dict = functions_data.get(node.name, {}) - return function_data.get("returnType", None) + return function_data.get("returnType") diff --git a/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py b/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py index b196fdc53..a8e72b160 100644 --- a/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py +++ b/src/codegen/sdk/typescript/external/ts_declassify/ts_declassify.py @@ -2,6 +2,7 @@ import os import shutil import subprocess +from typing import NoReturn from codegen.sdk.core.external.external_process import ExternalProcess @@ -9,7 +10,7 @@ class TSDeclassify(ExternalProcess): - def __init__(self, repo_path: str, base_path: str, working_dir: str = "/tmp/ts_declassify"): + def __init__(self, repo_path: str, base_path: str, working_dir: str = "/tmp/ts_declassify") -> None: super().__init__(repo_path, base_path) self.working_dir = working_dir @@ -18,7 +19,7 @@ def __init__(self, repo_path: str, base_path: str, working_dir: str = "/tmp/ts_d msg = "NodeJS or npm is not installed" raise RuntimeError(msg) - def _start(self): + def _start(self) -> None: try: logger.info("Installing ts-declassify...") @@ -54,13 +55,13 @@ def _start(self): except Exception as e: self._error = e logger.exception(f"Error installing ts-declassify: {e}") - raise e + raise - def reparse(self): + def reparse(self) -> NoReturn: msg = "TSDeclassify does not support reparse" raise NotImplementedError(msg) - def declassify(self, source: str, filename: str = "file.tsx", error_on_failure: bool = True): + def declassify(self, source: str, filename: str = "file.tsx", error_on_failure: bool = True) -> str: assert self.ready(), "TSDeclassify is not ready" try: diff --git a/src/codegen/sdk/typescript/external/typescript_analyzer/package-lock.json b/src/codegen/sdk/typescript/external/typescript_analyzer/package-lock.json new file mode 100644 index 000000000..66f60088e --- /dev/null +++ b/src/codegen/sdk/typescript/external/typescript_analyzer/package-lock.json @@ -0,0 +1,828 @@ +{ + "name": "typescript-analyzer", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "typescript-analyzer", + "version": "1.0.0", + "dependencies": { + "typescript": "^5.0.0", + "yargs": "^17.7.2" + }, + "devDependencies": { + "@rollup/plugin-commonjs": "^28.0.0", + "@rollup/plugin-json": "^6.0.0", + "@rollup/plugin-node-resolve": "^15.0.0", + "@rollup/plugin-typescript": "^12.0.0", + "@rollup/plugin-virtual": "^3.0.2", + "@types/node": "^22.0.0", + "@types/yargs": "^17.0.32", + "rollup": "^4.9.0", + "ts-node": "^10.9.1", + "tslib": "^2.6.0" + } + }, + "node_modules/@cspotcode/source-map-support": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz", + "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "0.3.9" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz", + "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.0.3", + "@jridgewell/sourcemap-codec": "^1.4.10" + } + }, + "node_modules/@rollup/plugin-commonjs": { + "version": "28.0.2", + "resolved": "https://registry.npmjs.org/@rollup/plugin-commonjs/-/plugin-commonjs-28.0.2.tgz", + "integrity": "sha512-BEFI2EDqzl+vA1rl97IDRZ61AIwGH093d9nz8+dThxJNH8oSoB7MjWvPCX3dkaK1/RCJ/1v/R1XB15FuSs0fQw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rollup/pluginutils": "^5.0.1", + "commondir": "^1.0.1", + "estree-walker": "^2.0.2", + "fdir": "^6.2.0", + "is-reference": "1.2.1", + "magic-string": "^0.30.3", + "picomatch": "^4.0.2" + }, + "engines": { + "node": ">=16.0.0 || 14 >= 14.17" + }, + "peerDependencies": { + "rollup": "^2.68.0||^3.0.0||^4.0.0" + }, + "peerDependenciesMeta": { + "rollup": { + "optional": true + } + } + }, + "node_modules/@rollup/plugin-json": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/@rollup/plugin-json/-/plugin-json-6.1.0.tgz", + "integrity": "sha512-EGI2te5ENk1coGeADSIwZ7G2Q8CJS2sF120T7jLw4xFw9n7wIOXHo+kIYRAoVpJAN+kmqZSoO3Fp4JtoNF4ReA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rollup/pluginutils": "^5.1.0" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "rollup": "^1.20.0||^2.0.0||^3.0.0||^4.0.0" + }, + "peerDependenciesMeta": { + "rollup": { + "optional": true + } + } + }, + "node_modules/@rollup/plugin-node-resolve": { + "version": "15.3.1", + "resolved": "https://registry.npmjs.org/@rollup/plugin-node-resolve/-/plugin-node-resolve-15.3.1.tgz", + "integrity": "sha512-tgg6b91pAybXHJQMAAwW9VuWBO6Thi+q7BCNARLwSqlmsHz0XYURtGvh/AuwSADXSI4h/2uHbs7s4FzlZDGSGA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rollup/pluginutils": "^5.0.1", + "@types/resolve": "1.20.2", + "deepmerge": "^4.2.2", + "is-module": "^1.0.0", + "resolve": "^1.22.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "rollup": "^2.78.0||^3.0.0||^4.0.0" + }, + "peerDependenciesMeta": { + "rollup": { + "optional": true + } + } + }, + "node_modules/@rollup/plugin-typescript": { + "version": "12.1.2", + "resolved": "https://registry.npmjs.org/@rollup/plugin-typescript/-/plugin-typescript-12.1.2.tgz", + "integrity": "sha512-cdtSp154H5sv637uMr1a8OTWB0L1SWDSm1rDGiyfcGcvQ6cuTs4MDk2BVEBGysUWago4OJN4EQZqOTl/QY3Jgg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@rollup/pluginutils": "^5.1.0", + "resolve": "^1.22.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "rollup": "^2.14.0||^3.0.0||^4.0.0", + "tslib": "*", + "typescript": ">=3.7.0" + }, + "peerDependenciesMeta": { + "rollup": { + "optional": true + }, + "tslib": { + "optional": true + } + } + }, + "node_modules/@rollup/plugin-virtual": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@rollup/plugin-virtual/-/plugin-virtual-3.0.2.tgz", + "integrity": "sha512-10monEYsBp3scM4/ND4LNH5Rxvh3e/cVeL3jWTgZ2SrQ+BmUoQcopVQvnaMcOnykb1VkxUFuDAN+0FnpTFRy2A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "rollup": "^1.20.0||^2.0.0||^3.0.0||^4.0.0" + }, + "peerDependenciesMeta": { + "rollup": { + "optional": true + } + } + }, + "node_modules/@rollup/pluginutils": { + "version": "5.1.4", + "resolved": "https://registry.npmjs.org/@rollup/pluginutils/-/pluginutils-5.1.4.tgz", + "integrity": "sha512-USm05zrsFxYLPdWWq+K3STlWiT/3ELn3RcV5hJMghpeAIhxfsUIg6mt12CBJBInWMV4VneoV7SfGv8xIwo2qNQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0", + "estree-walker": "^2.0.2", + "picomatch": "^4.0.2" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "rollup": "^1.20.0||^2.0.0||^3.0.0||^4.0.0" + }, + "peerDependenciesMeta": { + "rollup": { + "optional": true + } + } + }, + "node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.32.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.32.1.tgz", + "integrity": "sha512-zCpKHioQ9KgZToFp5Wvz6zaWbMzYQ2LJHQ+QixDKq52KKrF65ueu6Af4hLlLWHjX1Wf/0G5kSJM9PySW9IrvHA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@tsconfig/node10": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.11.tgz", + "integrity": "sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@tsconfig/node12": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz", + "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==", + "dev": true, + "license": "MIT" + }, + "node_modules/@tsconfig/node14": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz", + "integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==", + "dev": true, + "license": "MIT" + }, + "node_modules/@tsconfig/node16": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.4.tgz", + "integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/estree": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.6.tgz", + "integrity": "sha512-AYnb1nQyY49te+VRAVgmzfcgjYS91mY5P0TKUDCLEM+gNnA+3T6rWITXRLYCpahpqSQbN5cE+gHpnPyXjHWxcw==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "22.12.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.12.0.tgz", + "integrity": "sha512-Fll2FZ1riMjNmlmJOdAyY5pUbkftXslB5DgEzlIuNaiWhXd00FhWxVC/r4yV/4wBb9JfImTu+jiSvXTkJ7F/gA==", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~6.20.0" + } + }, + "node_modules/@types/resolve": { + "version": "1.20.2", + "resolved": "https://registry.npmjs.org/@types/resolve/-/resolve-1.20.2.tgz", + "integrity": "sha512-60BCwRFOZCQhDncwQdxxeOEEkbc5dIMccYLwbxsS4TUNeVECQ/pBJ0j09mrHOl/JJvpRPGwO9SvE4nR2Nb/a4Q==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/yargs": { + "version": "17.0.33", + "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-17.0.33.tgz", + "integrity": "sha512-WpxBCKWPLr4xSsHgz511rFJAM+wS28w2zEO1QDNY5zM/S8ok70NNfztH0xwhqKyaK0OHCbN98LDAZuy1ctxDkA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/yargs-parser": "*" + } + }, + "node_modules/@types/yargs-parser": { + "version": "21.0.3", + "resolved": "https://registry.npmjs.org/@types/yargs-parser/-/yargs-parser-21.0.3.tgz", + "integrity": "sha512-I4q9QU9MQv4oEOz4tAHJtNz1cwuLxn2F3xcc2iV5WdqLPpUnj30aUuxt1mAxYTG+oe8CZMV/+6rU4S4gRDzqtQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/acorn": { + "version": "8.14.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz", + "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", + "dev": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-walk": { + "version": "8.3.4", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz", + "integrity": "sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==", + "dev": true, + "license": "MIT", + "dependencies": { + "acorn": "^8.11.0" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/arg": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", + "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==", + "dev": true, + "license": "MIT" + }, + "node_modules/cliui": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", + "integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==", + "license": "ISC", + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.1", + "wrap-ansi": "^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "license": "MIT" + }, + "node_modules/commondir": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/commondir/-/commondir-1.0.1.tgz", + "integrity": "sha512-W9pAhw0ja1Edb5GVdIF1mjZw/ASI0AlShXM83UUGe2DVr5TdAPEA1OA8m/g8zWp9x6On7gqufY+FatDbC3MDQg==", + "dev": true, + "license": "MIT" + }, + "node_modules/create-require": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", + "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/deepmerge": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz", + "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/diff": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", + "integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.3.1" + } + }, + "node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/estree-walker": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", + "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", + "dev": true, + "license": "MIT" + }, + "node_modules/fdir": { + "version": "6.4.3", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.3.tgz", + "integrity": "sha512-PMXmW2y1hDDfTSRc9gaXIuCCRpuoz3Kaz8cUelp3smouvfT632ozg2vrT6lJsHKKOF59YLbOGfAWGUcKEfRMQw==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "picomatch": "^3 || ^4" + }, + "peerDependenciesMeta": { + "picomatch": { + "optional": true + } + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", + "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "license": "ISC", + "engines": { + "node": "6.* || 8.* || >= 10.*" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.16.1.tgz", + "integrity": "sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/is-module": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-module/-/is-module-1.0.0.tgz", + "integrity": "sha512-51ypPSPCoTEIN9dy5Oy+h4pShgJmPCygKfyRCISBI+JoWT/2oJvK8QPxmwv7b/p239jXrm9M1mlQbyKJ5A152g==", + "dev": true, + "license": "MIT" + }, + "node_modules/is-reference": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/is-reference/-/is-reference-1.2.1.tgz", + "integrity": "sha512-U82MsXXiFIrjCK4otLT+o2NA2Cd2g5MLoOVXUZjIOhLurrRxpEXzI8O0KZHr3IjLvlAH1kTPYSuqer5T9ZVBKQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "*" + } + }, + "node_modules/magic-string": { + "version": "0.30.17", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.17.tgz", + "integrity": "sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0" + } + }, + "node_modules/make-error": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz", + "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==", + "dev": true, + "license": "ISC" + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true, + "license": "MIT" + }, + "node_modules/picomatch": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", + "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/require-directory": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", + "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/resolve": { + "version": "1.22.10", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.10.tgz", + "integrity": "sha512-NPRy+/ncIMeDlTAsuqwKIiferiawhefFJtkNSW0qZJEqMEb+qBt/77B/jGeeek+F0uOeN05CDa6HXbbIgtVX4w==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-core-module": "^2.16.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/rollup": { + "version": "4.32.1", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.32.1.tgz", + "integrity": "sha512-z+aeEsOeEa3mEbS1Tjl6sAZ8NE3+AalQz1RJGj81M+fizusbdDMoEJwdJNHfaB40Scr4qNu+welOfes7maKonA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "1.0.6" + }, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.32.1", + "@rollup/rollup-android-arm64": "4.32.1", + "@rollup/rollup-darwin-arm64": "4.32.1", + "@rollup/rollup-darwin-x64": "4.32.1", + "@rollup/rollup-freebsd-arm64": "4.32.1", + "@rollup/rollup-freebsd-x64": "4.32.1", + "@rollup/rollup-linux-arm-gnueabihf": "4.32.1", + "@rollup/rollup-linux-arm-musleabihf": "4.32.1", + "@rollup/rollup-linux-arm64-gnu": "4.32.1", + "@rollup/rollup-linux-arm64-musl": "4.32.1", + "@rollup/rollup-linux-loongarch64-gnu": "4.32.1", + "@rollup/rollup-linux-powerpc64le-gnu": "4.32.1", + "@rollup/rollup-linux-riscv64-gnu": "4.32.1", + "@rollup/rollup-linux-s390x-gnu": "4.32.1", + "@rollup/rollup-linux-x64-gnu": "4.32.1", + "@rollup/rollup-linux-x64-musl": "4.32.1", + "@rollup/rollup-win32-arm64-msvc": "4.32.1", + "@rollup/rollup-win32-ia32-msvc": "4.32.1", + "@rollup/rollup-win32-x64-msvc": "4.32.1", + "fsevents": "~2.3.2" + } + }, + "node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/ts-node": { + "version": "10.9.2", + "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz", + "integrity": "sha512-f0FFpIdcHgn8zcPSbf1dRevwt047YMnaiJM3u2w2RewrB+fob/zePZcrOyQoLMMO7aBIddLcQIEK5dYjkLnGrQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@cspotcode/source-map-support": "^0.8.0", + "@tsconfig/node10": "^1.0.7", + "@tsconfig/node12": "^1.0.7", + "@tsconfig/node14": "^1.0.0", + "@tsconfig/node16": "^1.0.2", + "acorn": "^8.4.1", + "acorn-walk": "^8.1.1", + "arg": "^4.1.0", + "create-require": "^1.1.0", + "diff": "^4.0.1", + "make-error": "^1.1.1", + "v8-compile-cache-lib": "^3.0.1", + "yn": "3.1.1" + }, + "bin": { + "ts-node": "dist/bin.js", + "ts-node-cwd": "dist/bin-cwd.js", + "ts-node-esm": "dist/bin-esm.js", + "ts-node-script": "dist/bin-script.js", + "ts-node-transpile-only": "dist/bin-transpile.js", + "ts-script": "dist/bin-script-deprecated.js" + }, + "peerDependencies": { + "@swc/core": ">=1.2.50", + "@swc/wasm": ">=1.2.50", + "@types/node": "*", + "typescript": ">=2.7" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "@swc/wasm": { + "optional": true + } + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD" + }, + "node_modules/typescript": { + "version": "5.7.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.3.tgz", + "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "dev": true, + "license": "MIT" + }, + "node_modules/v8-compile-cache-lib": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", + "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==", + "dev": true, + "license": "MIT" + }, + "node_modules/wrap-ansi": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/y18n": { + "version": "5.0.8", + "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", + "integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==", + "license": "ISC", + "engines": { + "node": ">=10" + } + }, + "node_modules/yargs": { + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", + "license": "MIT", + "dependencies": { + "cliui": "^8.0.1", + "escalade": "^3.1.1", + "get-caller-file": "^2.0.5", + "require-directory": "^2.1.1", + "string-width": "^4.2.3", + "y18n": "^5.0.5", + "yargs-parser": "^21.1.1" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/yargs-parser": { + "version": "21.1.1", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", + "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/yn": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz", + "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + } + } +} diff --git a/src/codegen/sdk/typescript/file.py b/src/codegen/sdk/typescript/file.py index 6ba3d2d7a..cc52a48ca 100644 --- a/src/codegen/sdk/typescript/file.py +++ b/src/codegen/sdk/typescript/file.py @@ -391,9 +391,7 @@ def valid_import_names(self) -> dict[str, Symbol | TSImport]: valid_export_names = {} if len(self.default_exports) == 1: valid_export_names["default"] = self.default_exports[0] - for export in self.exports: - for name, dest in export.names: - valid_export_names[name] = dest + valid_export_names.update({name: dest for export in self.exports for name, dest in export.names}) return valid_export_names #################################################################################################################### diff --git a/src/codegen/sdk/typescript/statements/switch_case.py b/src/codegen/sdk/typescript/statements/switch_case.py index 01a49d72c..1e93fdc67 100644 --- a/src/codegen/sdk/typescript/statements/switch_case.py +++ b/src/codegen/sdk/typescript/statements/switch_case.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_graph import CodebaseGraph - from src.codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement + from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement @ts_apidoc diff --git a/src/codegen/sdk/typescript/symbol_groups/comment_group.py b/src/codegen/sdk/typescript/symbol_groups/comment_group.py index b15be09dc..809fe1bc1 100644 --- a/src/codegen/sdk/typescript/symbol_groups/comment_group.py +++ b/src/codegen/sdk/typescript/symbol_groups/comment_group.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self from codegen.sdk.core.symbol_groups.comment_group import CommentGroup from codegen.sdk.typescript.statements.comment import TSComment, TSCommentType @@ -12,7 +12,7 @@ @ts_apidoc class TSCommentGroup(CommentGroup): - """A group of related symbols that represent a comment or docstring in TypeScript + """A group of related symbols that represent a comment or docstring in TypeScript. For example: ``` @@ -62,7 +62,7 @@ def _get_sibbling_comments(symbol: TSSymbol) -> list[TSComment]: @classmethod @noapidoc - def from_symbol_comments(cls, symbol: TSSymbol): + def from_symbol_comments(cls, symbol: TSSymbol) -> Self | None: comment_nodes = cls._get_sibbling_comments(symbol) if not comment_nodes: return None @@ -70,7 +70,7 @@ def from_symbol_comments(cls, symbol: TSSymbol): @classmethod @noapidoc - def from_symbol_inline_comments(cls, symbol: TSSymbol): + def from_symbol_inline_comments(cls, symbol: TSSymbol) -> Self | None: # Locate the body that contains the comment nodes current_node = symbol.ts_node parent_node = symbol.ts_node.parent @@ -89,12 +89,11 @@ def from_symbol_inline_comments(cls, symbol: TSSymbol): comment_nodes = [] # Check if there are any comments after the function node - if function_index + 1 < len(parent_node.children): - if parent_node.children[function_index + 1].type == "comment": - # Check if the comment is on the same line - if parent_node.children[function_index].end_point[0] == parent_node.children[function_index + 1].start_point[0]: - comment = TSComment.from_code_block(parent_node.children[function_index + 1], symbol) - comment_nodes.append(comment) + if function_index + 1 < len(parent_node.children) and parent_node.children[function_index + 1].type == "comment": + # Check if the comment is on the same line + if parent_node.children[function_index].end_point[0] == parent_node.children[function_index + 1].start_point[0]: + comment = TSComment.from_code_block(parent_node.children[function_index + 1], symbol) + comment_nodes.append(comment) if not comment_nodes: return None @@ -104,7 +103,7 @@ def from_symbol_inline_comments(cls, symbol: TSSymbol): @classmethod @noapidoc def from_docstring(cls, symbol: TSSymbol) -> TSCommentGroup | None: - """Returns the docstring of the function""" + """Returns the docstring of the function.""" comment_nodes = cls._get_sibbling_comments(symbol) if not comment_nodes: return None @@ -116,7 +115,7 @@ def from_docstring(cls, symbol: TSSymbol) -> TSCommentGroup | None: @classmethod @noapidoc - def from_comment_nodes(cls, comment_nodes: list[TSComment], symbol: TSSymbol): + def from_comment_nodes(cls, comment_nodes: list[TSComment], symbol: TSSymbol) -> Self | None: if not comment_nodes: return None diff --git a/src/codegen/sdk/typescript/symbol_groups/dict.py b/src/codegen/sdk/typescript/symbol_groups/dict.py index 78f37cfcc..b6d7748d3 100644 --- a/src/codegen/sdk/typescript/symbol_groups/dict.py +++ b/src/codegen/sdk/typescript/symbol_groups/dict.py @@ -1,20 +1,22 @@ +from __future__ import annotations + import logging from typing import TYPE_CHECKING, Self, TypeVar, override -from tree_sitter import Node as TSNode - from codegen.sdk.core.autocommit import writer from codegen.sdk.core.expressions import Expression from codegen.sdk.core.expressions.string import String from codegen.sdk.core.interfaces.editable import Editable from codegen.sdk.core.interfaces.has_attribute import HasAttribute -from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.core.symbol_groups.dict import Dict, Pair from codegen.sdk.extensions.autocommit import reader from codegen.shared.decorators.docs import apidoc, noapidoc, ts_apidoc if TYPE_CHECKING: + from tree_sitter import Node as TSNode + from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.core.node_id_factory import NodeId Parent = TypeVar("Parent", bound="Editable") TExpression = TypeVar("TExpression", bound=Expression) @@ -37,7 +39,7 @@ class TSPair(Pair): shorthand: bool - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent) -> None: super().__init__(ts_node, file_node_id, G, parent) self.shorthand = ts_node.type == "shorthand_property_identifier" @@ -62,7 +64,7 @@ def _get_key_value(self) -> tuple[Expression[Self] | None, Expression[Self] | No @writer def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -> None: - """Reduces an editable to the following condition""" + """Reduces an editable to the following condition.""" if self.shorthand and node == self.value: # Object shorthand self.parent[self.key.source] = self.G.node_classes.bool_conversion[bool_condition] @@ -72,9 +74,9 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - @apidoc class TSDict(Dict, HasAttribute): - """A typescript dict object. You can use standard operations to operate on this dict (IE len, del, set, get, etc)""" + """A typescript dict object. You can use standard operations to operate on this dict (IE len, del, set, get, etc).""" - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: "CodebaseGraph", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = TSPair) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, parent: Parent, delimiter: str = ",", pair_type: type[Pair] = TSPair) -> None: super().__init__(ts_node, file_node_id, G, parent, delimiter=delimiter, pair_type=pair_type) def __getitem__(self, __key: str) -> TExpression: @@ -85,15 +87,13 @@ def __getitem__(self, __key: str) -> TExpression: if isinstance(pair.key, String): if pair.key.content == str(__key): pair_match = pair - elif pair.key is not None: - if pair.key.source == str(__key): - pair_match = pair + elif pair.key is not None and pair.key.source == str(__key): + pair_match = pair if pair_match: if pair_match.value is not None: return pair_match.value - else: - return pair_match.key + return pair_match.key msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" raise KeyError(msg) @@ -106,9 +106,8 @@ def __setitem__(self, __key: str, __value: TExpression) -> None: if isinstance(pair.key, String): if pair.key.content == str(__key): pair_match = pair - elif pair.key is not None: - if pair.key.source == str(__key): - pair_match = pair + elif pair.key is not None and pair.key.source == str(__key): + pair_match = pair if pair_match: # CASE: {a: b} @@ -118,11 +117,10 @@ def __setitem__(self, __key: str, __value: TExpression) -> None: else: pair.value.edit(f"{new_value}") # CASE: {a} + elif __key == new_value: + pair_match.edit(f"{__key}") else: - if __key == new_value: - pair_match.edit(f"{__key}") - else: - pair_match.edit(f"{__key}: {new_value}") + pair_match.edit(f"{__key}: {new_value}") break # CASE: {} else: @@ -140,5 +138,5 @@ def __setitem__(self, __key: str, __value: TExpression) -> None: @reader @noapidoc @override - def resolve_attribute(self, name: str) -> "Expression | None": + def resolve_attribute(self, name: str) -> Expression | None: return self.get(name, None) diff --git a/src/codegen/sdk/typescript/ts_config.py b/src/codegen/sdk/typescript/ts_config.py index 806621c5c..670138338 100644 --- a/src/codegen/sdk/typescript/ts_config.py +++ b/src/codegen/sdk/typescript/ts_config.py @@ -50,7 +50,7 @@ class TSConfig: # when computing the import resolution. _import_optimization_enabled: bool = False - def __init__(self, config_file: File, config_parser: "TSConfigParser"): + def __init__(self, config_file: File, config_parser: "TSConfigParser") -> None: self.config_file = config_file self.config_parser = config_parser # Try to parse the config file as JSON5. Fallback to empty dict if it fails. @@ -64,10 +64,10 @@ def __init__(self, config_file: File, config_parser: "TSConfigParser"): # Precompute the base config, base url, paths, and references self._precompute_config_values() - def __repr__(self): + def __repr__(self) -> str: return f"TSConfig({self.config_file.filepath})" - def _precompute_config_values(self): + def _precompute_config_values(self) -> None: """Precomputes the base config, base url, paths, and references.""" # Precompute the base config self._base_config = None @@ -138,7 +138,7 @@ def _precompute_config_values(self): self._references = [*self_references] # MAYBE add base references here? This breaks the reference chain though. self._self_references = self_references - def _precompute_import_aliases(self): + def _precompute_import_aliases(self) -> None: """Precomputes the import aliases.""" if self._computed_path_import_aliases: return diff --git a/src/codegen/shared/compilation/string_to_code.py b/src/codegen/shared/compilation/string_to_code.py index 84da113a1..3a600363b 100644 --- a/src/codegen/shared/compilation/string_to_code.py +++ b/src/codegen/shared/compilation/string_to_code.py @@ -59,7 +59,7 @@ def closure_func() -> Callable[[Any], None]: _line_offset = line_offset # Wrap the func for better tracing - def wrapped_func(*args, **kwargs): + def wrapped_func(*args, **kwargs) -> None: """Wraps the user code to capture and format exceptions + grab locals""" try: linecache.cache[""] = (len(_func_str), None, _func_str.splitlines(True), "") diff --git a/src/codegen/shared/exceptions/control_flow.py b/src/codegen/shared/exceptions/control_flow.py index ba26ffcaf..f5888b553 100644 --- a/src/codegen/shared/exceptions/control_flow.py +++ b/src/codegen/shared/exceptions/control_flow.py @@ -5,7 +5,7 @@ class StopCodemodException(Exception): threshold: int | None = None - def __init__(self, message: str | None = None, threshold: int | None = None): + def __init__(self, message: str | None = None, threshold: int | None = None) -> None: super().__init__(message) self.threshold = threshold diff --git a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py index ab93b59ea..b292ec171 100644 --- a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py +++ b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py @@ -35,7 +35,7 @@ class ChangeJSXElementName(Codemod, Skill): language = ProgrammingLanguage.TYPESCRIPT @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase): + def execute(self, codebase: Codebase) -> None: # Grab the NewName component PrivateRoutesContainer = codebase.get_symbol("PrivateRoutesContainer", optional=True) if PrivateRoutesContainer is None or not PrivateRoutesContainer.is_jsx: diff --git a/src/codemods/canonical/enum_mover/enum_mover.py b/src/codemods/canonical/enum_mover/enum_mover.py index 2a6522fa1..8bf4e3bb2 100644 --- a/src/codemods/canonical/enum_mover/enum_mover.py +++ b/src/codemods/canonical/enum_mover/enum_mover.py @@ -24,7 +24,7 @@ class EnumMover(Codemod, Skill): language = ProgrammingLanguage.PYTHON @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: CodebaseType): + def execute(self, codebase: CodebaseType) -> None: # Iterate over all classes in the codebase for cls in codebase.classes: # Check if the class is a subclass of Enum diff --git a/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py b/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py index 7ac12fed6..606c61ae6 100644 --- a/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py +++ b/src/codemods/canonical/move_functions_to_new_file/move_functions_to_new_file.py @@ -28,7 +28,7 @@ class MoveFunctionsToNewFile(Codemod, Skill): language = ProgrammingLanguage.PYTHON @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase): + def execute(self, codebase: Codebase) -> None: # Create a new file for storing the functions that contain pylsp util functions new_file: SourceFile = codebase.create_file("pylsp/pylsp_shared.py", "") for file in codebase.files: diff --git a/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py b/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py index 945f316c5..4206a05d7 100644 --- a/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py +++ b/src/codemods/canonical/openapi_add_response_none/openapi_add_response_none.py @@ -43,7 +43,7 @@ def get(self): language = ProgrammingLanguage.PYTHON @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase): + def execute(self, codebase: Codebase) -> None: def get_response_decorators(method: Symbol) -> list[Decorator]: """Returns a list of decorators that contain the string '.response' in the source code""" return [d for d in method.decorators if ".response" in d.source] diff --git a/src/codemods/canonical/split_file/split_file.py b/src/codemods/canonical/split_file/split_file.py index a61dc6401..52e06c825 100644 --- a/src/codemods/canonical/split_file/split_file.py +++ b/src/codemods/canonical/split_file/split_file.py @@ -20,7 +20,7 @@ class SplitFile(Codemod, Skill): language = ProgrammingLanguage.PYTHON @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: Codebase): + def execute(self, codebase: Codebase) -> None: # Grab large file to split file = codebase.get_file("sqlglot/optimizer/scope.py", optional=True) if file is None: diff --git a/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py b/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py index 85e757dba..85e511feb 100644 --- a/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py +++ b/src/codemods/canonical/split_file_and_rename_symbols/split_file_and_rename_symbols.py @@ -25,7 +25,7 @@ class SplitFileAndRenameSymbols(Codemod, Skill): language = ProgrammingLanguage.PYTHON @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.PYTHON) - def execute(self, codebase: CodebaseType): + def execute(self, codebase: CodebaseType) -> None: # Get file to split up source_file = codebase.get_file("redash/models/types.py", optional=True) if source_file is None: diff --git a/src/codemods/canonical/split_large_files/split_large_files.py b/src/codemods/canonical/split_large_files/split_large_files.py index 6ec618ac2..e8600dc88 100644 --- a/src/codemods/canonical/split_large_files/split_large_files.py +++ b/src/codemods/canonical/split_large_files/split_large_files.py @@ -21,7 +21,7 @@ class SplitLargeFiles(Codemod, Skill): language = ProgrammingLanguage.TYPESCRIPT @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def execute(self, codebase: Codebase): + def execute(self, codebase: Codebase) -> None: # Define constants for maximum lengths MAX_FILE_LENGTH = 500 MAX_SYMBOL_LENGTH = 50 diff --git a/src/codemods/codemod.py b/src/codemods/codemod.py index 53af94b21..1fdea1cae 100644 --- a/src/codemods/codemod.py +++ b/src/codemods/codemod.py @@ -4,7 +4,7 @@ class Codemod: execute: Callable | None = None - def __init__(self, name: str | None = None, execute: Callable | None = None, *args, **kwargs): + def __init__(self, name: str | None = None, execute: Callable | None = None, *args, **kwargs) -> None: self.name = name if execute: self.execute = execute diff --git a/tests/conftest.py b/tests/conftest.py index df179d306..65680b60f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ def find_dirs_to_ignore(start_dir, prefix): return dirs_to_ignore -def pytest_addoption(parser): +def pytest_addoption(parser) -> None: parser.addoption( "--size", action="append", @@ -68,7 +68,7 @@ def pytest_addoption(parser): # content of conftest.py -def pytest_configure(config): +def pytest_configure(config) -> None: worker_id = os.environ.get("PYTEST_XDIST_WORKER") if worker_id is not None: os.makedirs("build/logs", exist_ok=True) @@ -101,7 +101,7 @@ def pytest_runtest_makereport(item, call): @pytest.fixture(autouse=True) -def skip_lfs_tests(request): +def skip_lfs_tests(request) -> None: """Skip tests that depend on git LFS files if they haven't been pulled""" # Lets not run if we are in CI if os.getenv("CI") == "true" or os.getenv("CIRCLECI") == "true": @@ -111,7 +111,11 @@ def skip_lfs_tests(request): test_path = Path(request.module.__file__) # Only run for integration tests - if not str(test_path).startswith(str(Path.cwd() / "tests" / "integration")): + try: + cwd = Path.cwd() + except FileNotFoundError: + return + if not str(test_path).startswith(str(cwd / "tests" / "integration")): return try: diff --git a/tests/integration/codegen/git/clients/test_github_client_factory.py b/tests/integration/codegen/git/clients/test_github_client_factory.py index 668faae0e..00e81d8cf 100644 --- a/tests/integration/codegen/git/clients/test_github_client_factory.py +++ b/tests/integration/codegen/git/clients/test_github_client_factory.py @@ -2,7 +2,7 @@ from codegen.git.schemas.github import GithubType -def test_github_client_factory_create_from_token_no_token(): +def test_github_client_factory_create_from_token_no_token() -> None: github_client = GithubClientFactory.create_from_token(github_type=GithubType.Github) assert github_client.base_url == "https://api.github.com" repo = github_client.read_client.get_repo("python-lsp/python-lsp-server") @@ -10,7 +10,7 @@ def test_github_client_factory_create_from_token_no_token(): assert repo.name == "python-lsp-server" -def test_github_client_factory_create_from_repo(repo_config): +def test_github_client_factory_create_from_repo(repo_config) -> None: github_client = GithubClientFactory.create_from_repo(repo_config=repo_config, github_type=GithubType.Github) repo = github_client.read_client.get_repo("codegen-sh/Kevin-s-Adventure-Game") assert repo.full_name == "codegen-sh/Kevin-s-Adventure-Game" diff --git a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py index 99604e0d4..f3f2ee97a 100644 --- a/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py +++ b/tests/integration/codegen/git/repo_operator/test_remote_repo_operator.py @@ -18,7 +18,7 @@ def op(repo_config, request, tmpdir): @pytest.mark.parametrize("op", shallow_options, ids=lambda x: f"shallow={x}", indirect=True) @patch("codegen.git.clients.github_client.Github") -def test_checkout_branch(mock_git_client, op: RemoteRepoOperator): +def test_checkout_branch(mock_git_client, op: RemoteRepoOperator) -> None: mock_git_client.return_value = Github("test_token", "https://api.github.com") op.pull_repo() op.checkout_commit(op.head_commit) @@ -40,7 +40,7 @@ def test_checkout_branch(mock_git_client, op: RemoteRepoOperator): @pytest.mark.parametrize("op", [True], ids=lambda x: f"shallow={x}", indirect=True) @patch("codegen.git.clients.github_client.Github") -def test_checkout_branch_local_already_checked_out(mock_git_client, op: RemoteRepoOperator): +def test_checkout_branch_local_already_checked_out(mock_git_client, op: RemoteRepoOperator) -> None: mock_git_client.return_value = Github("test_token", "https://api.github.com") op.checkout_commit(op.head_commit) @@ -60,7 +60,7 @@ def test_checkout_branch_local_already_checked_out(mock_git_client, op: RemoteRe @pytest.mark.parametrize("op", [True], ids=lambda x: f"shallow={x}", indirect=True) @patch("codegen.git.clients.github_client.Github") -def test_checkout_branch_remote_already_checked_out_resets_branch(mock_git_client, op: RemoteRepoOperator): +def test_checkout_branch_remote_already_checked_out_resets_branch(mock_git_client, op: RemoteRepoOperator) -> None: mock_git_client.return_value = Github("test_token", "https://api.github.com") original_commit_head = op.head_commit diff --git a/tests/integration/codegen/test_imports.py b/tests/integration/codegen/test_imports.py index e2af00f64..7dc889965 100644 --- a/tests/integration/codegen/test_imports.py +++ b/tests/integration/codegen/test_imports.py @@ -2,10 +2,10 @@ from codegen import Codebase -def test_codegen_imports(): +def test_codegen_imports() -> None: # Test decorated function @codegen.function(name="sample_codemod") - def run(codebase): + def run(codebase) -> None: pass # Test class diff --git a/tests/integration/codemod/test_diffs.py b/tests/integration/codemod/test_diffs.py index 5745caef2..d0f59443b 100644 --- a/tests/integration/codemod/test_diffs.py +++ b/tests/integration/codemod/test_diffs.py @@ -21,7 +21,7 @@ def test_codemods_diffs(_codebase: Codebase, expected: Path) -> None: verify_diffs(_codebase) -def verify_diffs(_codebase): +def verify_diffs(_codebase) -> None: modified = gather_modified_files(_codebase) diffs = [DiffLite.from_git_diff(diff) for diff in _codebase.get_diffs()] _codebase.G.apply_diffs(diffs) diff --git a/tests/shared/codemod/codebase_comparison_utils.py b/tests/shared/codemod/codebase_comparison_utils.py index 85659d307..e2f8781f1 100644 --- a/tests/shared/codemod/codebase_comparison_utils.py +++ b/tests/shared/codemod/codebase_comparison_utils.py @@ -149,7 +149,7 @@ def capture_single_file_stats(files_dir: Path, files, expected_dir: Path, extens return diff -def capture_modified_file_stats(left_lines: list[str], right_lines: list[str], diff_stats: dict): +def capture_modified_file_stats(left_lines: list[str], right_lines: list[str], diff_stats: dict) -> None: """Captures stats on the differences between two complementary files, a file that exists in both the expected and modified codebases. """ diff --git a/tests/shared/codemod/commands.py b/tests/shared/codemod/commands.py index b7d2984dc..74bacc08c 100644 --- a/tests/shared/codemod/commands.py +++ b/tests/shared/codemod/commands.py @@ -23,24 +23,24 @@ def codemod() -> None: @codemod.command() @click.option("--extra-repos", is_flag=True) -def generate_cases(extra_repos: bool = False): +def generate_cases(extra_repos: bool = False) -> None: """Generate cases for codemod tests. Very slow""" repos = find_repos(extra_repos=extra_repos) for codemod in find_codemods(): - for repo_name, repo in repos.items(): + for repo_name in repos.keys(): (codemod.test_dir / f"test_{repo_name}").mkdir(parents=True, exist_ok=True) _generate_diffs(extra_repos=extra_repos) _clean_diffs(aggressive=True) -def _generate_diffs(extra_repos: bool = False): +def _generate_diffs(extra_repos: bool = False) -> None: """Generate diffs for codemod tests""" os.system(f"pytest tests/integration/codemod/test_codemods.py::test_codemods_cloned_repos --size small --extra-repos={str(extra_repos).lower()} -n auto --snapshot-update") os.system(f"pytest tests/integration/codemod/test_codemods.py::test_codemods_cloned_repos --size large --extra-repos={str(extra_repos).lower()} -n auto --snapshot-update") @codemod.command() -def generate_diffs(): +def generate_diffs() -> None: """Generate diffs for codemod tests""" _generate_diffs() _clean_diffs() @@ -67,7 +67,7 @@ def gather_repos_per_codemod() -> dict[str, dict[tuple[Size, bool], list[ClonedR MAX_CASES = {Size.Small: 1, Size.Large: 1} -def _clean_diffs(aggressive: bool = False): +def _clean_diffs(aggressive: bool = False) -> None: repos = {**find_repos(extra_repos=True), **find_repos(extra_repos=False)} for test_case in find_codemod_test_cases(repos=repos): @@ -93,7 +93,7 @@ def _clean_diffs(aggressive: bool = False): @codemod.command() @click.option("--aggressive", is_flag=True) -def clean_diffs(aggressive: bool = False): +def clean_diffs(aggressive: bool = False) -> None: _clean_diffs(aggressive) @@ -204,7 +204,7 @@ def _fetch_and_store_codemod(repo_id: str, url: str, cli_api_key: str) -> tuple[ @codemod.command() @click.option("--cli-api-key", required=True, help="API key for authentication") -def fetch_verified_codemods(cli_api_key: str): +def fetch_verified_codemods(cli_api_key: str) -> None: """Fetch codemods for all repos in REPO_ID_TO_URL and save to JSON files.""" VERIFIED_CODEMOD_DATA_DIR.mkdir(parents=True, exist_ok=True) repos_to_commits: dict[str, list[dict]] = {} diff --git a/tests/shared/codemod/verified_codemod_utils.py b/tests/shared/codemod/verified_codemod_utils.py index af5984158..0fb5ea40c 100644 --- a/tests/shared/codemod/verified_codemod_utils.py +++ b/tests/shared/codemod/verified_codemod_utils.py @@ -83,7 +83,7 @@ def anonymized_name(self) -> str: class CodemodAPI: - def __init__(self, api_key: str | None = None, modal_prefix: str = "https://codegen-sh"): + def __init__(self, api_key: str | None = None, modal_prefix: str = "https://codegen-sh") -> None: self.api_key = api_key self.modal_prefix = modal_prefix self.get_codemods_url = f"{self.modal_prefix}--{GET_CODEMODS_URL_SUFFIX}" diff --git a/tests/shared/mocks/mock_ai_helper.py b/tests/shared/mocks/mock_ai_helper.py index b163cfeb3..203d977b9 100644 --- a/tests/shared/mocks/mock_ai_helper.py +++ b/tests/shared/mocks/mock_ai_helper.py @@ -31,5 +31,5 @@ def llm_query_functions_with_retry(self, model: str, messages: list, functions: def llm_query_functions(self, model: str, messages: list, functions: list[dict], max_tokens: int | None = None) -> None: pass - def llm_response_to_json(response) -> str: + def llm_response_to_json(self) -> str: pass diff --git a/tests/shared/skills/decorators.py b/tests/shared/skills/decorators.py index c6acec556..4b1179e3f 100644 --- a/tests/shared/skills/decorators.py +++ b/tests/shared/skills/decorators.py @@ -77,7 +77,7 @@ def properly_named_function(name: str, language: ProgrammingLanguage): return False -def skill_func_name_error(name: str, language: ProgrammingLanguage): +def skill_func_name_error(name: str, language: ProgrammingLanguage) -> str: if language == ProgrammingLanguage.PYTHON: return f"Function name must be 'skill_func', 'python_skill_func' or 'execute' for Python skills. Got {name}." elif language == ProgrammingLanguage.TYPESCRIPT: diff --git a/tests/shared/skills/skill.py b/tests/shared/skills/skill.py index 17b695faa..06525d4a7 100644 --- a/tests/shared/skills/skill.py +++ b/tests/shared/skills/skill.py @@ -29,7 +29,7 @@ def python_skill_func(codebase: PyCodebaseType) -> callable: ... def typescript_skill_func(codebase: TSCodebaseType) -> callable: ... @staticmethod - def skill_func(codebase: CodebaseType): ... + def skill_func(codebase: CodebaseType) -> None: ... def execute(self, codebase: Codebase) -> None: ... diff --git a/tests/shared/skills/skill_implementation.py b/tests/shared/skills/skill_implementation.py index dc83a7049..c2d3c2d79 100644 --- a/tests/shared/skills/skill_implementation.py +++ b/tests/shared/skills/skill_implementation.py @@ -37,7 +37,7 @@ def __init__( si_id: int | None = None, from_app: bool = False, external: bool = False, - ): + ) -> None: self.name = name or skill_func.__name__ self.language = language self.test_cases = [] if skip_test else test_cases @@ -122,6 +122,6 @@ def function_body(self) -> str: def __call__(self, codebase: CodebaseType): self._skill_func(codebase) - def __str__(self): + def __str__(self) -> str: source = inspect.getsource(self._skill_func) return textwrap.dedent(source).strip() diff --git a/tests/shared/skills/utils.py b/tests/shared/skills/utils.py index 342d7d061..f206e43f4 100644 --- a/tests/shared/skills/utils.py +++ b/tests/shared/skills/utils.py @@ -17,7 +17,7 @@ CODEMODS_PATH: Path = REPO_ROOT_PATH / "src" / "codemods" / "canonical" -def import_skills_from(path: Path, module_prefix: str): +def import_skills_from(path: Path, module_prefix: str) -> None: for file in path.rglob("*.py"): relative_path = file.relative_to(path) if "evaluation" in relative_path.parts or "__init__" in file.name: @@ -26,7 +26,7 @@ def import_skills_from(path: Path, module_prefix: str): importlib.import_module(module) -def import_all_skills(): +def import_all_skills() -> None: import_skills_from(SKILLS_PATH, "tests.unit.skills.implementations") import_skills_from(CODEMODS_PATH, "codemods.canonical") diff --git a/tests/shared/utils/recursion.py b/tests/shared/utils/recursion.py index c816a4fab..2c5e4fbae 100644 --- a/tests/shared/utils/recursion.py +++ b/tests/shared/utils/recursion.py @@ -5,7 +5,7 @@ logger = logging.getLogger(__name__) -def set_recursion_limit(): +def set_recursion_limit() -> None: sys.setrecursionlimit(10**9) if sys.platform == "linux": logger.info(f"Setting stack limit to {resource.RLIM_INFINITY}") diff --git a/tests/unit/codegen/cli/conftest.py b/tests/unit/codegen/cli/conftest.py index 3263638ca..2d6421397 100644 --- a/tests/unit/codegen/cli/conftest.py +++ b/tests/unit/codegen/cli/conftest.py @@ -32,4 +32,7 @@ def initialized_repo(sample_repository: Path, runner: CliRunner): subprocess.run(["git", "add", "."], cwd=sample_repository, check=True) subprocess.run(["git", "commit", "-m", "Initialize codegen"], cwd=sample_repository, check=True) yield sample_repository - shutil.rmtree(sample_repository) + try: + shutil.rmtree(sample_repository) + except FileNotFoundError: + pass diff --git a/tests/unit/codegen/cli/test_reset.py b/tests/unit/codegen/cli/test_reset.py index cbbbb19fd..ae0c235e5 100644 --- a/tests/unit/codegen/cli/test_reset.py +++ b/tests/unit/codegen/cli/test_reset.py @@ -47,7 +47,7 @@ def committed_repo(initialized_repo: Path, committed_state: dict[str, str]) -> P return initialized_repo -def setup_repo_state(repo_path: Path, state: dict[str, str]): +def setup_repo_state(repo_path: Path, state: dict[str, str]) -> None: """Helper to set up files in the repo""" for filepath, content in state.items(): file_path = repo_path / filepath @@ -93,7 +93,7 @@ def get_git_status(repo_path: Path) -> tuple[set[str], set[str], set[str]]: def verify_git_state( repo_path: Path, expected_staged: set[str] | None = None, expected_modified: set[str] | None = None, expected_untracked: set[str] | None = None, rename_pairs: list[tuple[str, str]] | None = None -): +) -> None: """Verify git status matches expected state""" if rename_pairs is not None: for old_path, new_path in rename_pairs: @@ -114,7 +114,7 @@ def verify_git_state( assert untracked == expected_untracked, f"Untracked files mismatch.\nExpected: {expected_untracked}\nActual: {untracked}" -def verify_repo_state(repo_path: Path, expected_content: dict[str, str | None]): +def verify_repo_state(repo_path: Path, expected_content: dict[str, str | None]) -> None: """Verify file contents in repo""" for path, content in expected_content.items(): file_path = repo_path / path @@ -265,7 +265,7 @@ def create_test_case( ), ], ) -def test_reset(committed_repo: Path, committed_state: dict[str, str], test_case: ResetTestCase, runner: CliRunner): +def test_reset(committed_repo: Path, committed_state: dict[str, str], test_case: ResetTestCase, runner: CliRunner) -> None: """Test reset command with various scenarios""" # Get test case from factory function if needed if callable(test_case): @@ -299,7 +299,7 @@ def test_reset(committed_repo: Path, committed_state: dict[str, str], test_case: ) -def test_reset_with_mixed_states(committed_repo: Path, committed_state: dict[str, str], runner: CliRunner): +def test_reset_with_mixed_states(committed_repo: Path, committed_state: dict[str, str], runner: CliRunner) -> None: """Test reset with a mix of staged, unstaged, and untracked changes""" # 1. Staged modifications staged_changes = { @@ -348,7 +348,7 @@ def test_reset_with_mixed_states(committed_repo: Path, committed_state: dict[str ) -def test_reset_with_mixed_renames(committed_repo: Path, committed_state: dict[str, str], runner: CliRunner): +def test_reset_with_mixed_renames(committed_repo: Path, committed_state: dict[str, str], runner: CliRunner) -> None: """Test reset with a mix of staged and unstaged renames""" # 1. Staged rename staged_changes = { diff --git a/tests/unit/codegen/git/clients/test_git_repo_client.py b/tests/unit/codegen/git/clients/test_git_repo_client.py index c729dc9a6..d2b2cace0 100644 --- a/tests/unit/codegen/git/clients/test_git_repo_client.py +++ b/tests/unit/codegen/git/clients/test_git_repo_client.py @@ -7,7 +7,7 @@ @patch("codegen.git.clients.git_repo_client.GithubClientFactory") def test_delete_branch_default( mock_github_client_factory, -): +) -> None: git_repo_client = GitRepoClient(repo_config=MagicMock(), access_scope=GithubScope.WRITE) git_repo_client.read_client = MagicMock(default_branch="default-branch") git_repo_client.delete_branch(branch_name="default-branch") @@ -18,7 +18,7 @@ def test_delete_branch_default( @patch("codegen.git.clients.git_repo_client.GithubClientFactory") def test_delete_branch_non_default_branch( mock_github_client_factory, -): +) -> None: git_repo_client = GitRepoClient(repo_config=MagicMock(), access_scope=GithubScope.WRITE) git_repo_client.read_client = MagicMock(default_branch="default-branch") mock_ref = MagicMock() @@ -30,7 +30,7 @@ def test_delete_branch_non_default_branch( @patch("codegen.git.clients.git_repo_client.GithubClientFactory") def test_delete_branch_cannot_write_branch( mock_github_client_factory, -): +) -> None: git_repo_client = GitRepoClient(repo_config=MagicMock(), access_scope=GithubScope.WRITE) git_repo_client.read_client = MagicMock(default_branch="default-branch") git_repo_client.delete_branch(branch_name="not-default-branch") diff --git a/tests/unit/codegen/git/schemas/test_github.py b/tests/unit/codegen/git/schemas/test_github.py index 26d2b4d3a..a875b58ca 100644 --- a/tests/unit/codegen/git/schemas/test_github.py +++ b/tests/unit/codegen/git/schemas/test_github.py @@ -1,6 +1,6 @@ from codegen.git.schemas.github import GithubType -def test_github_type_base_url(): +def test_github_type_base_url() -> None: assert GithubType.Github.base_url == "https://github.com" assert GithubType.GithubEnterprise.base_url == "https://github.codegen.app" diff --git a/tests/unit/codegen/gscli/test_cli.py b/tests/unit/codegen/gscli/test_cli.py index 9c3356d96..96c11bdae 100644 --- a/tests/unit/codegen/gscli/test_cli.py +++ b/tests/unit/codegen/gscli/test_cli.py @@ -1,2 +1,2 @@ -def test_cli(): +def test_cli() -> None: import codegen.gscli # noqa: F401 diff --git a/tests/unit/codegen/runner/sandbox/test_executor.py b/tests/unit/codegen/runner/sandbox/test_executor.py index 0c1fa5f7a..60094b061 100644 --- a/tests/unit/codegen/runner/sandbox/test_executor.py +++ b/tests/unit/codegen/runner/sandbox/test_executor.py @@ -17,7 +17,7 @@ @pytest.mark.asyncio -async def test_execute_func_pass_in_codemod_context_takes_priority(executor: SandboxExecutor): +async def test_execute_func_pass_in_codemod_context_takes_priority(executor: SandboxExecutor) -> None: codemod_context = CodemodContext( CODEMOD_LINK="http://codegen.sh/codemod/5678", ) @@ -85,7 +85,7 @@ async def test_execute_func_pass_in_codemod_context_takes_priority(executor: San @pytest.mark.asyncio -async def test_run_max_preview_time_exceeded_sets_observation_meta(executor: SandboxExecutor): +async def test_run_max_preview_time_exceeded_sets_observation_meta(executor: SandboxExecutor) -> None: mock_source = """ codebase.files[0].edit("a = 2") """ @@ -97,7 +97,7 @@ async def test_run_max_preview_time_exceeded_sets_observation_meta(executor: San @pytest.mark.asyncio -async def test_run_max_ai_requests_error_sets_observation_meta(executor: SandboxExecutor): +async def test_run_max_ai_requests_error_sets_observation_meta(executor: SandboxExecutor) -> None: mock_source = """ codebase.ai("tell me a joke") """ @@ -109,7 +109,7 @@ async def test_run_max_ai_requests_error_sets_observation_meta(executor: Sandbox @pytest.mark.asyncio -async def test_run_max_transactions_exceeded_sets_observation_meta(executor: SandboxExecutor): +async def test_run_max_transactions_exceeded_sets_observation_meta(executor: SandboxExecutor) -> None: mock_source = """ codebase.files[0].edit("a = 2") """ @@ -122,7 +122,7 @@ async def test_run_max_transactions_exceeded_sets_observation_meta(executor: San @pytest.mark.asyncio -async def test_find_flag_groups_with_subdirectories(executor: SandboxExecutor): +async def test_find_flag_groups_with_subdirectories(executor: SandboxExecutor) -> None: groups = await executor.find_flag_groups( code_flags=[ CodeFlag( @@ -151,7 +151,7 @@ async def test_find_flag_groups_with_subdirectories(executor: SandboxExecutor): @pytest.mark.asyncio -async def test_find_flag_groups_with_group_by(executor: SandboxExecutor): +async def test_find_flag_groups_with_group_by(executor: SandboxExecutor) -> None: groups = await executor.find_flag_groups( code_flags=[ CodeFlag( @@ -185,7 +185,7 @@ async def test_find_flag_groups_with_group_by(executor: SandboxExecutor): @pytest.mark.asyncio @pytest.mark.parametrize("codebase", [121], indirect=True) -async def test_find_flag_groups_with_group_by_app(executor: SandboxExecutor): +async def test_find_flag_groups_with_group_by_app(executor: SandboxExecutor) -> None: groups = await executor.find_flag_groups( code_flags=[ CodeFlag( @@ -213,7 +213,7 @@ async def test_find_flag_groups_with_group_by_app(executor: SandboxExecutor): @pytest.mark.skip(reason="TODO: add max_prs as part of find_flag_groups") @pytest.mark.asyncio -async def test_find_flag_groups_with_max_prs(executor: SandboxExecutor): +async def test_find_flag_groups_with_max_prs(executor: SandboxExecutor) -> None: groups = await executor.find_flag_groups( code_flags=[ CodeFlag( diff --git a/tests/unit/codegen/runner/sandbox/test_runner.py b/tests/unit/codegen/runner/sandbox/test_runner.py index 0abd5c557..b1fb3a81d 100644 --- a/tests/unit/codegen/runner/sandbox/test_runner.py +++ b/tests/unit/codegen/runner/sandbox/test_runner.py @@ -7,7 +7,7 @@ @pytest.mark.asyncio @patch("codegen.runner.sandbox.executor.SandboxExecutor") -async def test_sandbox_runner_warmup_builds_graph(mock_executor, runner: SandboxRunner): +async def test_sandbox_runner_warmup_builds_graph(mock_executor, runner: SandboxRunner) -> None: await runner.warmup() assert runner.codebase.files assert len(runner.codebase.files) == 1 @@ -15,7 +15,7 @@ async def test_sandbox_runner_warmup_builds_graph(mock_executor, runner: Sandbox @pytest.mark.asyncio @patch("codegen.runner.sandbox.runner.SandboxRunner._build_graph") -async def test_sandbox_runner_warmup_builds_graph_throws(mock_build_graph, runner: SandboxRunner): +async def test_sandbox_runner_warmup_builds_graph_throws(mock_build_graph, runner: SandboxRunner) -> None: mock_build_graph.side_effect = Exception("Test exception") with pytest.raises(Exception): @@ -25,7 +25,7 @@ async def test_sandbox_runner_warmup_builds_graph_throws(mock_build_graph, runne @pytest.mark.asyncio @patch("codegen.runner.sandbox.runner.logger") @patch("codegen.runner.sandbox.runner.SandboxExecutor") -async def test_sandbox_runner_warmup_logs_repo_id(mock_executor, mock_logger, runner: SandboxRunner): +async def test_sandbox_runner_warmup_logs_repo_id(mock_executor, mock_logger, runner: SandboxRunner) -> None: await runner.warmup() assert runner.codebase.files assert len(runner.codebase.files) == 1 @@ -35,7 +35,7 @@ async def test_sandbox_runner_warmup_logs_repo_id(mock_executor, mock_logger, ru @pytest.mark.asyncio @patch("codegen.runner.sandbox.runner.SandboxExecutor") -async def test_sandbox_runner_warmup_starts_with_default_branch(mock_executor, runner: SandboxRunner): +async def test_sandbox_runner_warmup_starts_with_default_branch(mock_executor, runner: SandboxRunner) -> None: await runner.warmup() # assert True is returned # assert len(runner.codebase._op.git_cli.branches) == 1 TODO: fix GHA creating master and main branch assert not runner.codebase._op.git_cli.head.is_detached @@ -47,7 +47,7 @@ async def test_sandbox_runner_warmup_starts_with_default_branch(mock_executor, r @patch("codegen.runner.sandbox.runner.logger") @patch("codegen.runner.sandbox.runner.SandboxExecutor") @patch("codegen.sdk.core.codebase.Codebase.default_branch", new_callable=PropertyMock) -async def test_sandbox_runner_reset_runner_deletes_branches(mock_branch, mock_executor, mock_logger, runner: SandboxRunner): +async def test_sandbox_runner_reset_runner_deletes_branches(mock_branch, mock_executor, mock_logger, runner: SandboxRunner) -> None: mock_branch.return_value = "main" await runner.warmup() num_branches = len(runner.codebase._op.git_cli.heads) # TODO: fix GHA creating master and main branch and assert the len is 1 at the start diff --git a/tests/unit/codegen/runner/utils/test_branch_name.py b/tests/unit/codegen/runner/utils/test_branch_name.py index 6b3d807a5..7aae328db 100644 --- a/tests/unit/codegen/runner/utils/test_branch_name.py +++ b/tests/unit/codegen/runner/utils/test_branch_name.py @@ -3,13 +3,13 @@ from codegen.runner.utils.branch_name import get_head_branch_name -def test_get_head_branch_name_no_group(): +def test_get_head_branch_name_no_group() -> None: codemod = MagicMock(epic_id=123, version_id=456, run_id=789) branch_name = get_head_branch_name(codemod=codemod, group=None) assert branch_name == "codegen-codemod-123-version-456-run-789-group-0" -def test_get_head_branch_name_with_group(): +def test_get_head_branch_name_with_group() -> None: codemod = MagicMock(epic_id=123, version_id=456, run_id=789) group = MagicMock(id=2) branch_name = get_head_branch_name(codemod=codemod, group=group) diff --git a/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py b/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py index 3a8ebf689..f9f74fb1a 100644 --- a/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py +++ b/tests/unit/codegen/sdk/benchmark/codebase/test_codebase_reset.py @@ -22,13 +22,13 @@ def setup_codebase(num_files: int, extension: str, tmp_path: Path): return codebase, files -def reset_codebase(codebase: Codebase): +def reset_codebase(codebase: Codebase) -> None: codebase.reset() @pytest.mark.benchmark(group="sdk-benchmark", min_time=1, max_time=5, disable_gc=True) @pytest.mark.parametrize("extension", ["txt", "py"]) -def test_codebase_reset_stress_test(extension: str, tmp_path, benchmark): +def test_codebase_reset_stress_test(extension: str, tmp_path, benchmark) -> None: def setup(): codebase, _ = setup_codebase(NUM_FILES, extension, tmp_path) return ((codebase,), {}) @@ -38,7 +38,7 @@ def setup(): @pytest.mark.timeout(5, func_only=True) @pytest.mark.parametrize("extension", ["txt", "py"]) -def test_codebase_reset_correctness(extension: str, tmp_path): +def test_codebase_reset_correctness(extension: str, tmp_path) -> None: codebase, files = setup_codebase(NUM_FILES, extension, tmp_path) codebase.reset() for file, original_content in files.items(): diff --git a/tests/unit/codegen/sdk/code_generation/test_codemod_writer_decorators.py b/tests/unit/codegen/sdk/code_generation/test_codemod_writer_decorators.py index 39e2fc20b..4e6b695c8 100644 --- a/tests/unit/codegen/sdk/code_generation/test_codemod_writer_decorators.py +++ b/tests/unit/codegen/sdk/code_generation/test_codemod_writer_decorators.py @@ -8,7 +8,7 @@ @pytest.mark.skip("broken after repo split, fix or move over") -def test_get_documented_objects(): +def test_get_documented_objects() -> None: """Test that importing everything in codegen-sdk doesn't invoke any functionality (incase someone leaves actual functionality at the top level of a file). diff --git a/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py b/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py index a084ae143..c61a26308 100644 --- a/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py +++ b/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_reset.py @@ -10,7 +10,7 @@ ], indirect=["original", "expected"], ) -def test_codebase_reset(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset(codebase: Codebase, assert_expected, tmp_path) -> None: # External change should be preserved (tmp_path / "a.py").write_text("b") # Programmatic change should be reset @@ -27,7 +27,7 @@ def test_codebase_reset(codebase: Codebase, assert_expected, tmp_path): ], indirect=["original", "expected"], ) -def test_codebase_reset_external_changes(codebase: Codebase, assert_expected): +def test_codebase_reset_external_changes(codebase: Codebase, assert_expected) -> None: # External change should be preserved codebase.get_file("a.py").path.write_text("b") codebase.commit() @@ -42,7 +42,7 @@ def test_codebase_reset_external_changes(codebase: Codebase, assert_expected): ], indirect=["original", "expected"], ) -def test_codebase_reset_manual_file_add(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset_manual_file_add(codebase: Codebase, assert_expected, tmp_path) -> None: # Manually create a new file - should be preserved new_file = tmp_path / "new.py" new_file.write_text("new content") @@ -60,7 +60,7 @@ def test_codebase_reset_manual_file_add(codebase: Codebase, assert_expected, tmp ], indirect=["original", "expected"], ) -def test_codebase_reset_manual_file_delete(codebase: Codebase, assert_expected): +def test_codebase_reset_manual_file_delete(codebase: Codebase, assert_expected) -> None: # Manual deletion should be preserved codebase.get_file("b.py").path.unlink() # Programmatic change should be reset @@ -77,7 +77,7 @@ def test_codebase_reset_manual_file_delete(codebase: Codebase, assert_expected): ], indirect=["original", "expected"], ) -def test_codebase_reset_manual_file_rename(codebase: Codebase, tmp_path, assert_expected): +def test_codebase_reset_manual_file_rename(codebase: Codebase, tmp_path, assert_expected) -> None: # Manual rename should be preserved old_path = codebase.get_file("old.py").path new_path = tmp_path / "new.py" @@ -104,7 +104,7 @@ def test_codebase_reset_manual_file_rename(codebase: Codebase, tmp_path, assert_ ], indirect=["original", "expected"], ) -def test_codebase_reset_nested_directories(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset_nested_directories(codebase: Codebase, assert_expected, tmp_path) -> None: """Test reset with nested directory structure.""" # External changes should be preserved (tmp_path / "src/main.py").write_text("def main():\n print('modified')") @@ -133,7 +133,7 @@ def test_codebase_reset_nested_directories(codebase: Codebase, assert_expected, ], indirect=["original", "expected"], ) -def test_codebase_reset_mixed_content(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset_mixed_content(codebase: Codebase, assert_expected, tmp_path) -> None: """Test reset with different types of file content.""" # External changes should be preserved (tmp_path / "config.json").write_text('{\n "debug": false,\n "env": "prod"\n}') @@ -187,7 +187,7 @@ def reset(self): ], indirect=["original", "expected"], ) -def test_codebase_reset_large_file(codebase: Codebase, assert_expected): +def test_codebase_reset_large_file(codebase: Codebase, assert_expected) -> None: """Test reset with a larger file containing multiple methods.""" codebase.get_file("module.py").edit("""class ModifiedClass: def __init__(self): @@ -204,7 +204,7 @@ def __init__(self): ], indirect=["original", "expected"], ) -def test_codebase_reset_preserves_external_changes(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset_preserves_external_changes(codebase: Codebase, assert_expected, tmp_path) -> None: # Make external changes to existing file src_dir = tmp_path / "src" src_dir.mkdir(exist_ok=True) @@ -229,7 +229,7 @@ def test_codebase_reset_preserves_external_changes(codebase: Codebase, assert_ex ], indirect=["original", "expected"], ) -def test_codebase_reset_mixed_changes(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset_mixed_changes(codebase: Codebase, assert_expected, tmp_path) -> None: # Make programmatic change that should be reset codebase.get_file("src/utils.py").edit("def helper():\n return None") @@ -250,7 +250,7 @@ def test_codebase_reset_mixed_changes(codebase: Codebase, assert_expected, tmp_p ], indirect=["original", "expected"], ) -def test_codebase_reset_nested_external_changes(codebase: Codebase, assert_expected, tmp_path): +def test_codebase_reset_nested_external_changes(codebase: Codebase, assert_expected, tmp_path) -> None: # Create nested directory structure with changes config_dir = tmp_path / "config" config_dir.mkdir(exist_ok=True) @@ -277,7 +277,7 @@ def test_codebase_reset_nested_external_changes(codebase: Codebase, assert_expec ], indirect=["original", "expected"], ) -def test_codebase_reset_multiple_programmatic_edits(codebase: Codebase, assert_expected): +def test_codebase_reset_multiple_programmatic_edits(codebase: Codebase, assert_expected) -> None: """Test reset after multiple programmatic edits to the same file.""" # Make multiple programmatic changes that should all be reset codebase.get_file("file.py").edit("first edit") @@ -302,7 +302,7 @@ def test_codebase_reset_multiple_programmatic_edits(codebase: Codebase, assert_e ], indirect=["original", "expected"], ) -def test_codebase_reset_interleaved_changes(codebase: Codebase, assert_expected): +def test_codebase_reset_interleaved_changes(codebase: Codebase, assert_expected) -> None: """Test reset with interleaved programmatic and external changes.""" # Interleave programmatic and external changes codebase.get_file("file.py").edit("def main():\n return 1") @@ -335,7 +335,7 @@ def method1(self): ], indirect=["original", "expected"], ) -def test_codebase_reset_complex_changes(codebase: Codebase, assert_expected): +def test_codebase_reset_complex_changes(codebase: Codebase, assert_expected) -> None: """Test reset with a mix of content additions, modifications, and external changes.""" # Make several programmatic changes for i in range(5): diff --git a/tests/unit/codegen/sdk/codebase/file_graph/test_symbol_parent_statement.py b/tests/unit/codegen/sdk/codebase/file_graph/test_symbol_parent_statement.py index 5eb6e77e2..017f1163d 100644 --- a/tests/unit/codegen/sdk/codebase/file_graph/test_symbol_parent_statement.py +++ b/tests/unit/codegen/sdk/codebase/file_graph/test_symbol_parent_statement.py @@ -6,7 +6,7 @@ from codegen.sdk.python import PyFile -def test_parent_statement_function(tmpdir): +def test_parent_statement_function(tmpdir) -> None: # language=python content = """ def foo(): @@ -17,7 +17,7 @@ def foo(): assert codebase.get_function("foo").parent_statement == file.code_block.statements[0] -def test_parent_statement_class(tmpdir): +def test_parent_statement_class(tmpdir) -> None: # language=python content = """ class Foo: @@ -28,7 +28,7 @@ class Foo: assert codebase.get_class("Foo").parent_statement == file.code_block.statements[0] -def test_parent_statement_assignment(tmpdir): +def test_parent_statement_assignment(tmpdir) -> None: # language=python content = """ foo = 1 @@ -38,7 +38,7 @@ def test_parent_statement_assignment(tmpdir): assert file.get_global_var("foo").parent_statement == file.code_block.statements[0] -def test_parent_statement_nested_symbols(tmpdir): +def test_parent_statement_nested_symbols(tmpdir) -> None: # language=python content = """ logger = get_logger() diff --git a/tests/unit/codegen/sdk/codebase/flagging/test_code_flag.py b/tests/unit/codegen/sdk/codebase/flagging/test_code_flag.py index d7beba6ec..1491360a6 100644 --- a/tests/unit/codegen/sdk/codebase/flagging/test_code_flag.py +++ b/tests/unit/codegen/sdk/codebase/flagging/test_code_flag.py @@ -2,7 +2,7 @@ from codegen.sdk.codebase.flagging.enums import MessageType -def test_code_flag_properties(tmpdir): +def test_code_flag_properties(tmpdir) -> None: # language=python content = """ class Foo: diff --git a/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py b/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py index 179a6ffb2..be07f8c93 100644 --- a/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py +++ b/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py @@ -4,7 +4,7 @@ from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper -def test_group_all(): +def test_group_all() -> None: flag1 = CodeFlag( symbol=MagicMock( file=MagicMock(filepath="test.py"), diff --git a/tests/unit/codegen/sdk/codebase/transaction_manager/test_transaction_mangager_checks.py b/tests/unit/codegen/sdk/codebase/transaction_manager/test_transaction_mangager_checks.py index 807559392..b3c30ae73 100644 --- a/tests/unit/codegen/sdk/codebase/transaction_manager/test_transaction_mangager_checks.py +++ b/tests/unit/codegen/sdk/codebase/transaction_manager/test_transaction_mangager_checks.py @@ -6,7 +6,7 @@ from codegen.shared.exceptions.control_flow import MaxPreviewTimeExceeded, MaxTransactionsExceeded -def test_check_max_preview_time_exceeded(tmpdir): +def test_check_max_preview_time_exceeded(tmpdir) -> None: with pytest.raises(MaxPreviewTimeExceeded) as exc_info: with get_codebase_session( tmpdir=tmpdir, @@ -20,7 +20,7 @@ def test_check_max_preview_time_exceeded(tmpdir): assert exc_info.value.threshold == 0 -def test_check_max_preview_time_exceeded_set_session_options(tmpdir): +def test_check_max_preview_time_exceeded_set_session_options(tmpdir) -> None: with get_codebase_session( tmpdir=tmpdir, files={"file_a.py": "a = 1", "file_b": "b = 1"}, @@ -34,7 +34,7 @@ def test_check_max_preview_time_exceeded_set_session_options(tmpdir): assert exc_info.value.threshold == 0 -def test_check_max_transactions_exceeded(tmpdir): +def test_check_max_transactions_exceeded(tmpdir) -> None: with pytest.raises(MaxTransactionsExceeded) as exc_info: with get_codebase_session( tmpdir=tmpdir, @@ -48,7 +48,7 @@ def test_check_max_transactions_exceeded(tmpdir): assert exc_info.value.threshold == 0 -def test_check_max_transactions_exceeded_set_session_options(tmpdir): +def test_check_max_transactions_exceeded_set_session_options(tmpdir) -> None: with get_codebase_session( tmpdir=tmpdir, files={"file_a.py": "a = 1", "file_b": "b = 1"}, diff --git a/tests/unit/codegen/sdk/conftest.py b/tests/unit/codegen/sdk/conftest.py index a9b12f9a5..03162cc4a 100644 --- a/tests/unit/codegen/sdk/conftest.py +++ b/tests/unit/codegen/sdk/conftest.py @@ -28,7 +28,7 @@ def codebase(tmp_path, original: dict[str, str], programming_language: Programmi @pytest.fixture def assert_expected(expected: dict[str, str], tmp_path): - def assert_expected(codebase: Codebase): + def assert_expected(codebase: Codebase) -> None: codebase.commit() for file in expected: assert tmp_path.joinpath(file).exists() diff --git a/tests/unit/codegen/sdk/python/codebase/test_codebase_raise_error.py b/tests/unit/codegen/sdk/python/codebase/test_codebase_raise_error.py index a183ea2ef..8aedd9ba3 100644 --- a/tests/unit/codegen/sdk/python/codebase/test_codebase_raise_error.py +++ b/tests/unit/codegen/sdk/python/codebase/test_codebase_raise_error.py @@ -4,7 +4,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_python_exports_not_supported(tmpdir): +def test_python_exports_not_supported(tmpdir) -> None: """Test that exports are not supported in Python codebases.""" # language=python content = """ diff --git a/tests/unit/codegen/sdk/python/codebase/test_control_flow.py b/tests/unit/codegen/sdk/python/codebase/test_control_flow.py index 589f2a2fd..4aaa7c5dc 100644 --- a/tests/unit/codegen/sdk/python/codebase/test_control_flow.py +++ b/tests/unit/codegen/sdk/python/codebase/test_control_flow.py @@ -1,3 +1,5 @@ +from typing import NoReturn + import pytest from codegen.sdk.codebase.config import SessionOptions @@ -7,13 +9,13 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_max_transactions_exceeded_reached_set_threshold(tmpdir): +def test_max_transactions_exceeded_reached_set_threshold(tmpdir) -> None: e = MaxTransactionsExceeded("test exception", threshold=1) assert str(e) == "test exception" assert e.threshold == 1 -def test_raise_max_transactions_exceeded_reached_no_threshold(tmpdir): +def test_raise_max_transactions_exceeded_reached_no_threshold(tmpdir) -> NoReturn: with pytest.raises(MaxTransactionsExceeded) as exc_info: msg = "test exception" raise MaxTransactionsExceeded(msg) @@ -21,7 +23,7 @@ def test_raise_max_transactions_exceeded_reached_no_threshold(tmpdir): assert exc_info.value.threshold is None -def test_raise_max_transactions_exceeded_reached_with_threshold(tmpdir): +def test_raise_max_transactions_exceeded_reached_with_threshold(tmpdir) -> NoReturn: with pytest.raises(MaxTransactionsExceeded) as exc_info: msg = "test exception" raise MaxTransactionsExceeded(msg, threshold=1) @@ -29,7 +31,7 @@ def test_raise_max_transactions_exceeded_reached_with_threshold(tmpdir): assert exc_info.value.threshold == 1 -def test_max_transactions_exceeded_reached(tmpdir): +def test_max_transactions_exceeded_reached(tmpdir) -> None: with pytest.raises(MaxTransactionsExceeded) as exc_info: with get_codebase_session( tmpdir=tmpdir, @@ -43,7 +45,7 @@ def test_max_transactions_exceeded_reached(tmpdir): assert exc_info.value.threshold == 1 -def test_max_transactions_exceeded_reached_should_still_commit(tmpdir): +def test_max_transactions_exceeded_reached_should_still_commit(tmpdir) -> None: with pytest.raises(MaxTransactionsExceeded) as exc_info: with get_codebase_session( tmpdir=tmpdir, @@ -59,7 +61,7 @@ def test_max_transactions_exceeded_reached_should_still_commit(tmpdir): assert exc_info.value.threshold == 1 -def test_max_preview_time_exceeded_reached(tmpdir): +def test_max_preview_time_exceeded_reached(tmpdir) -> None: with pytest.raises(MaxPreviewTimeExceeded) as exc_info: with get_codebase_session( tmpdir=tmpdir, @@ -73,7 +75,7 @@ def test_max_preview_time_exceeded_reached(tmpdir): assert exc_info.value.threshold == 0 -def test_max_ai_requests_error_reached(tmpdir): +def test_max_ai_requests_error_reached(tmpdir) -> None: with pytest.raises(MaxAIRequestsError) as exc_info: with get_codebase_session( tmpdir=tmpdir, diff --git a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_indent.py b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_indent.py index 4bd8a726f..13d38ea7e 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_indent.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_indent.py @@ -7,7 +7,7 @@ from codegen.sdk.python import PyFile -def test_indent_left_once(tmpdir): +def test_indent_left_once(tmpdir) -> None: # language=python content = """ def foo(): @@ -32,7 +32,7 @@ def foo(): ) -def test_indent_left_exceed_limit(tmpdir): +def test_indent_left_exceed_limit(tmpdir) -> None: # language=python content = """ def foo(): @@ -57,7 +57,7 @@ def foo(): ) -def test_indent_right_once(tmpdir): +def test_indent_right_once(tmpdir) -> None: # language=python content = """ def foo(): @@ -82,7 +82,7 @@ def foo(): ) -def test_indent_right_multiple(tmpdir): +def test_indent_right_multiple(tmpdir) -> None: # language=python content = """ def foo(): @@ -107,7 +107,7 @@ def foo(): ) -def test_indent_zero(tmpdir): +def test_indent_zero(tmpdir) -> None: # language=python content = """ def foo(): @@ -122,7 +122,7 @@ def foo(): assert file.content == content -def test_indent_nested_code_blocks(tmpdir): +def test_indent_nested_code_blocks(tmpdir) -> None: # language=python content = """ def foo(): @@ -181,7 +181,7 @@ def foo(): ) -def test_indent_only_nested_block(tmpdir): +def test_indent_only_nested_block(tmpdir) -> None: # language=python content = """ def foo(): diff --git a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_insert.py b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_insert.py index 1b29bb754..d96496c04 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_insert.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_insert.py @@ -1,7 +1,7 @@ from codegen.sdk.codebase.factory.get_session import get_codebase_session -def test_code_block_append_single_element(tmpdir): +def test_code_block_append_single_element(tmpdir) -> None: # language=python content = """ def foo(): @@ -27,7 +27,7 @@ def foo(): ) -def test_code_block_append_multiple_element(tmpdir): +def test_code_block_append_multiple_element(tmpdir) -> None: # language=python content = """ def foo(): @@ -55,7 +55,7 @@ def foo(): ) -def test_code_block_insert_single_element_beginning(tmpdir): +def test_code_block_insert_single_element_beginning(tmpdir) -> None: # language=python content = """ def foo(): @@ -81,7 +81,7 @@ def foo(): ) -def test_code_block_insert_multiple_element_beginning(tmpdir): +def test_code_block_insert_multiple_element_beginning(tmpdir) -> None: # language=python content = """ def foo(): @@ -109,7 +109,7 @@ def foo(): ) -def test_code_block_insert_single_element_middle(tmpdir): +def test_code_block_insert_single_element_middle(tmpdir) -> None: # language=python content = """ def foo(): @@ -135,7 +135,7 @@ def foo(): ) -def test_code_block_insert_multiple_element_middle(tmpdir): +def test_code_block_insert_multiple_element_middle(tmpdir) -> None: # language=python content = """ def foo(): @@ -163,7 +163,7 @@ def foo(): ) -def test_code_block_insert_single_element_end(tmpdir): +def test_code_block_insert_single_element_end(tmpdir) -> None: # language=python content = """ def foo(): @@ -189,7 +189,7 @@ def foo(): ) -def test_code_block_insert_multiple_element_end(tmpdir): +def test_code_block_insert_multiple_element_end(tmpdir) -> None: # language=python content = """ def foo(): @@ -217,7 +217,7 @@ def foo(): ) -def test_code_block_insert_multiple_out_of_order(tmpdir): +def test_code_block_insert_multiple_out_of_order(tmpdir) -> None: # language=python content = """ def foo(): diff --git a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_remove.py b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_remove.py index 5c58933eb..41a0d8349 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_remove.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_remove.py @@ -1,7 +1,7 @@ from codegen.sdk.codebase.factory.get_session import get_codebase_session -def test_code_block_remove_multiple_element(tmpdir): +def test_code_block_remove_multiple_element(tmpdir) -> None: # language=python content = """ def foo(): @@ -25,7 +25,7 @@ def foo(): ) -def test_code_block_remove_insert_beginning(tmpdir): +def test_code_block_remove_insert_beginning(tmpdir) -> None: # language=python content = """ def foo(): @@ -51,7 +51,7 @@ def foo(): ) -def test_code_block_insert_remove_beginning(tmpdir): +def test_code_block_insert_remove_beginning(tmpdir) -> None: # language=python content = """ def foo(): @@ -77,7 +77,7 @@ def foo(): ) -def test_code_block_remove_insert_middle(tmpdir): +def test_code_block_remove_insert_middle(tmpdir) -> None: # language=python content = """ def foo(): @@ -103,7 +103,7 @@ def foo(): ) -def test_code_block_insert_remove_middle(tmpdir): +def test_code_block_insert_remove_middle(tmpdir) -> None: # language=python content = """ def foo(): diff --git a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_unwrap.py b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_unwrap.py index df9419975..8bf67b3eb 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_unwrap.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_unwrap.py @@ -7,7 +7,7 @@ from codegen.sdk.python import PyFile -def test_unwrap_flat_statements(tmpdir): +def test_unwrap_flat_statements(tmpdir) -> None: # language=python content = """ def foo(): @@ -36,7 +36,7 @@ def foo(): ) -def test_unwrap_nested_statements(tmpdir): +def test_unwrap_nested_statements(tmpdir) -> None: # language=python content = """ def foo(): @@ -72,7 +72,7 @@ def foo(): ) -def test_unwrap_sandwiched_statements(tmpdir): +def test_unwrap_sandwiched_statements(tmpdir) -> None: # language=python content = """ def foo(): @@ -108,7 +108,7 @@ def foo(): ) -def test_unwrap_multiline_wrapper(tmpdir): +def test_unwrap_multiline_wrapper(tmpdir) -> None: # language=python content = """ def foo(): diff --git a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_wrap.py b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_wrap.py index 046be8ff5..5fbd475ed 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_wrap.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/code_block/test_code_block_wrap.py @@ -6,7 +6,7 @@ from codegen.sdk.python import PyFile -def test_wrap_with_with_statement(tmpdir): +def test_wrap_with_with_statement(tmpdir) -> None: # language=python content = """ def foo(a: bool): @@ -33,7 +33,7 @@ def foo(a: bool): ) -def test_wrap_with_function(tmpdir): +def test_wrap_with_function(tmpdir) -> None: # language=python content = """ if a: diff --git a/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_calls_from_symbol.py b/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_calls_from_symbol.py index 9dd786afa..e6dc799b2 100644 --- a/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_calls_from_symbol.py +++ b/tests/unit/codegen/sdk/python/detached_symbols/function_call/test_function_calls_from_symbol.py @@ -10,7 +10,7 @@ from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement -def test_function_calls_from_file(tmpdir): +def test_function_calls_from_file(tmpdir) -> None: # language=python content = """ from some_file import x, y, z @@ -45,7 +45,7 @@ def random(): ] -def test_function_calls_from_class(tmpdir): +def test_function_calls_from_class(tmpdir) -> None: # language=python content = """ from some_file import bar @@ -86,7 +86,7 @@ def foo(self): assert third_call.parent.statement_type == StatementType.RETURN_STATEMENT -def test_function_calls_from_decorated_definitions(tmpdir): +def test_function_calls_from_decorated_definitions(tmpdir) -> None: # language=python content = """ @pytest.mark.parametrize("x", [1, 2, 3]) @@ -124,7 +124,7 @@ def __init__(self): @pytest.mark.xfail(reason="Broken by function call changes") -def test_function_calls_from_datatypes(tmpdir): +def test_function_calls_from_datatypes(tmpdir) -> None: # language=python content = """ def get_config(): @@ -168,7 +168,7 @@ def get_config(): ] -def test_function_calls_from_function_parameters(tmpdir): +def test_function_calls_from_function_parameters(tmpdir) -> None: # language=python content = """ # function parameters @@ -189,7 +189,7 @@ def greet(name=get_default_name()): ] -def test_function_calls_from_while_loop(tmpdir): +def test_function_calls_from_while_loop(tmpdir) -> None: # language=python content = """ # while loop conditions @@ -209,7 +209,7 @@ def test_function_calls_from_while_loop(tmpdir): ] -def test_function_calls_from_if_conditions(tmpdir): +def test_function_calls_from_if_conditions(tmpdir) -> None: # language=python content = """ # if conditions @@ -233,7 +233,7 @@ def test_function_calls_from_if_conditions(tmpdir): ] -def test_function_calls_from_with_clause(tmpdir): +def test_function_calls_from_with_clause(tmpdir) -> None: # language=python content = """ # with clause @@ -255,5 +255,5 @@ def test_function_calls_from_with_clause(tmpdir): @pytest.mark.skip(reason="Ellen: why is this test empty?") -def test_function_calls_from_function_calls(tmpdir): +def test_function_calls_from_function_calls(tmpdir) -> None: pass diff --git a/tests/unit/codegen/sdk/python/expressions/test_binary_expression.py b/tests/unit/codegen/sdk/python/expressions/test_binary_expression.py index 918d7d171..2725ffbfc 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_binary_expression.py +++ b/tests/unit/codegen/sdk/python/expressions/test_binary_expression.py @@ -9,7 +9,7 @@ from codegen.sdk.python import PyFile -def test_all_binary_expression_types(tmpdir): +def test_all_binary_expression_types(tmpdir) -> None: # language=python content = """ a = 1 + 2 @@ -133,7 +133,7 @@ def test_all_binary_expression_types(tmpdir): assert file.get_global_var("w").value.operator.source == "in" -def test_chained_binary_expressions(tmpdir): +def test_chained_binary_expressions(tmpdir) -> None: # language=python content = """ a = 1 + 2 - 3 * 4 / 5 % 6 ** 7 // 8 # binary operators @@ -157,7 +157,7 @@ def test_chained_binary_expressions(tmpdir): @pytest.mark.skip(reason="CG-8883: Parenthesized expressions not implemented yet") -def test_chained_multiline_binary_expressions_using_parenthesis(tmpdir): +def test_chained_multiline_binary_expressions_using_parenthesis(tmpdir) -> None: # language=python content = """ a = (1 + 2 - 3 * 4 @@ -185,7 +185,7 @@ def test_chained_multiline_binary_expressions_using_parenthesis(tmpdir): assert [x.source for x in c.operators] == ["and", "or", "and", "or"] -def test_chained_multiline_binary_expressions_using_backslash(tmpdir): +def test_chained_multiline_binary_expressions_using_backslash(tmpdir) -> None: # language=python content = """ a = 1 + 2 - 3 * 4 \ @@ -214,7 +214,7 @@ def test_chained_multiline_binary_expressions_using_backslash(tmpdir): @pytest.mark.skip(reason="CG-8886: Mixed expression groups not implemented yet") -def test_chained_mixed_binary_expressions(tmpdir): +def test_chained_mixed_binary_expressions(tmpdir) -> None: # language=python content = """ a = 1 + 2 == True != False or True and False * 12 @@ -238,7 +238,7 @@ def test_chained_mixed_binary_expressions(tmpdir): @pytest.mark.skip(reason="CG-8883: Parenthesized expressions not implemented yet") -def test_chained_mixed_multiline_binary_expressions_with_parenthesis(tmpdir): +def test_chained_mixed_multiline_binary_expressions_with_parenthesis(tmpdir) -> None: # language=python content = """ a = (1 + 2 == True @@ -253,7 +253,7 @@ def test_chained_mixed_multiline_binary_expressions_with_parenthesis(tmpdir): @pytest.mark.skip(reason="CG-8886: Mixed expression groups not implemented yet") -def test_chained_mixed_multiline_binary_expressions_with_backslash(tmpdir): +def test_chained_mixed_multiline_binary_expressions_with_backslash(tmpdir) -> None: # language=python content = """ a = 1 + 2 == True \ diff --git a/tests/unit/codegen/sdk/python/expressions/test_binary_reduce_condition.py b/tests/unit/codegen/sdk/python/expressions/test_binary_reduce_condition.py index 1e63e0fd1..d3aef419c 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_binary_reduce_condition.py +++ b/tests/unit/codegen/sdk/python/expressions/test_binary_reduce_condition.py @@ -7,7 +7,7 @@ from codegen.sdk.python.file import PyFile -def test_reduce_binary_simple(tmpdir): +def test_reduce_binary_simple(tmpdir) -> None: # language=python content = """ def foo(): @@ -31,7 +31,7 @@ def foo(): ) -def test_reduce_binary_complex_condition(tmpdir): +def test_reduce_binary_complex_condition(tmpdir) -> None: # language=python content = """ def bar(): @@ -58,7 +58,7 @@ def bar(): ) -def test_reduce_negation_condition(tmpdir): +def test_reduce_negation_condition(tmpdir) -> None: # language=python content = """ def baz(): @@ -84,7 +84,7 @@ def baz(): ) -def test_reduce_binary_with_string(tmpdir): +def test_reduce_binary_with_string(tmpdir) -> None: # language=python content = """ def qux(): @@ -108,7 +108,7 @@ def qux(): ) -def test_reduce_complex_condition_with_string(tmpdir): +def test_reduce_complex_condition_with_string(tmpdir) -> None: # language=python content = """ def quux(): diff --git a/tests/unit/codegen/sdk/python/expressions/test_boolean.py b/tests/unit/codegen/sdk/python/expressions/test_boolean.py index 51ac48ecb..cf3777914 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_boolean.py +++ b/tests/unit/codegen/sdk/python/expressions/test_boolean.py @@ -5,7 +5,7 @@ from codegen.sdk.python.assignment import PyAssignment -def test_boolean_parse(tmpdir): +def test_boolean_parse(tmpdir) -> None: # language=python content = """ a = True diff --git a/tests/unit/codegen/sdk/python/expressions/test_number.py b/tests/unit/codegen/sdk/python/expressions/test_number.py index 1452f1477..30cc9cf5a 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_number.py +++ b/tests/unit/codegen/sdk/python/expressions/test_number.py @@ -3,7 +3,7 @@ from codegen.sdk.core.expressions.number import Number -def test_number_parse(tmpdir): +def test_number_parse(tmpdir) -> None: # language=python content = """ a = 1 diff --git a/tests/unit/codegen/sdk/python/expressions/test_type.py b/tests/unit/codegen/sdk/python/expressions/test_type.py index 22cd7f8f5..a90325d16 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_type.py +++ b/tests/unit/codegen/sdk/python/expressions/test_type.py @@ -9,7 +9,7 @@ from codegen.sdk.python import PyAssignment, PyClass, PyFunction -def test_type_basic(tmpdir): +def test_type_basic(tmpdir) -> None: file = "test.py" # language=python content = """ @@ -35,7 +35,7 @@ def foo(a: str): ) -def test_type_generic(tmpdir): +def test_type_generic(tmpdir) -> None: file = "test.py" # language=python content = """ @@ -63,7 +63,7 @@ def foo(a: tuple[int, int, str]): ) -def test_type_union(tmpdir): +def test_type_union(tmpdir) -> None: file = "test.py" # language=python content = """ @@ -90,7 +90,7 @@ def foo(a: int | None | str): ) -def test_type_multi_file(tmpdir): +def test_type_multi_file(tmpdir) -> None: file2 = "test2.py" # language=python content2 = """ diff --git a/tests/unit/codegen/sdk/python/function/test_function_async.py b/tests/unit/codegen/sdk/python/function/test_function_async.py index 796b323b7..79bceadc0 100644 --- a/tests/unit/codegen/sdk/python/function/test_function_async.py +++ b/tests/unit/codegen/sdk/python/function/test_function_async.py @@ -2,7 +2,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_function_is_async_basic(tmpdir): +def test_function_is_async_basic(tmpdir) -> None: # language=python content = """ def foo(): @@ -56,7 +56,7 @@ async def qux(self): ) -def test_function_is_async_extended(tmpdir): +def test_function_is_async_extended(tmpdir) -> None: # language=python content = """ @my_decorator diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py b/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py index c31f55f22..fd3d817d8 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_is_dynamic.py @@ -2,7 +2,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_py_import_is_dynamic_in_function(tmpdir): +def test_py_import_is_dynamic_in_function(tmpdir) -> None: # language=python content = """ def my_function(): @@ -23,7 +23,7 @@ def my_function(): assert not imports[2].is_dynamic # import static_import -def test_py_import_is_dynamic_in_if_block(tmpdir): +def test_py_import_is_dynamic_in_if_block(tmpdir) -> None: # language=python content = """ import top_level # Static import @@ -41,7 +41,7 @@ def test_py_import_is_dynamic_in_if_block(tmpdir): assert imports[2].is_dynamic # from x import y -def test_py_import_is_dynamic_in_try_except(tmpdir): +def test_py_import_is_dynamic_in_try_except(tmpdir) -> None: # language=python content = """ import static_first # Static import @@ -61,7 +61,7 @@ def test_py_import_is_dynamic_in_try_except(tmpdir): assert imports[2].is_dynamic # from x.y import z -def test_py_import_is_dynamic_in_with_block(tmpdir): +def test_py_import_is_dynamic_in_with_block(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -79,7 +79,7 @@ def test_py_import_is_dynamic_in_with_block(tmpdir): assert imports[2].is_dynamic # from a.b import c -def test_py_import_is_dynamic_in_class_method(tmpdir): +def test_py_import_is_dynamic_in_class_method(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -103,7 +103,7 @@ def class_method(cls): assert imports[3].is_dynamic # another_dynamic import -def test_py_import_is_dynamic_in_nested_function(tmpdir): +def test_py_import_is_dynamic_in_nested_function(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -125,7 +125,7 @@ def inner_function(): assert imports[3].is_dynamic # from x import y -def test_py_import_is_dynamic_in_else_clause(tmpdir): +def test_py_import_is_dynamic_in_else_clause(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -145,7 +145,7 @@ def test_py_import_is_dynamic_in_else_clause(tmpdir): assert imports[2].is_dynamic # from x import y -def test_py_import_is_dynamic_in_except_clause(tmpdir): +def test_py_import_is_dynamic_in_except_clause(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -165,7 +165,7 @@ def test_py_import_is_dynamic_in_except_clause(tmpdir): assert imports[2].is_dynamic # from x import y -def test_py_import_is_dynamic_in_finally_clause(tmpdir): +def test_py_import_is_dynamic_in_finally_clause(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -187,7 +187,7 @@ def test_py_import_is_dynamic_in_finally_clause(tmpdir): assert imports[2].is_dynamic # from x import y -def test_py_import_is_dynamic_in_while_statement(tmpdir): +def test_py_import_is_dynamic_in_while_statement(tmpdir) -> None: # language=python content = """ import static_import # Static import @@ -205,7 +205,7 @@ def test_py_import_is_dynamic_in_while_statement(tmpdir): assert imports[2].is_dynamic # from a import b -def test_py_import_is_dynamic_in_match_case(tmpdir): +def test_py_import_is_dynamic_in_match_case(tmpdir) -> None: # language=python content = """ import static_import # Static import diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_reduce_block.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_reduce_block.py index d1eaecd13..d8141f0ea 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_reduce_block.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_reduce_block.py @@ -9,7 +9,7 @@ from codegen.sdk.python import PyFile -def test_reduce_condition_to_true_elif(tmpdir): +def test_reduce_condition_to_true_elif(tmpdir) -> None: # language=python content = """ def foo(): @@ -41,7 +41,7 @@ def foo(): ) -def test_reduce_condition_to_false_elif(tmpdir): +def test_reduce_condition_to_false_elif(tmpdir) -> None: # language=python content = """ def foo(): @@ -71,7 +71,7 @@ def foo(): ) -def test_reduce_condition_to_true_else(tmpdir): +def test_reduce_condition_to_true_else(tmpdir) -> None: # language=python content = """ def foo(): @@ -97,7 +97,7 @@ def foo(): ) -def test_reduce_condition_to_false_else(tmpdir): +def test_reduce_condition_to_false_else(tmpdir) -> None: # language=python content = """ def foo(): @@ -123,7 +123,7 @@ def foo(): ) -def test_reduce_condition_multiple_if_blocks(tmpdir): +def test_reduce_condition_multiple_if_blocks(tmpdir) -> None: # language=python content = """ def foo(): @@ -155,7 +155,7 @@ def foo(): ) -def test_reduce_condition_nested_if(tmpdir): +def test_reduce_condition_nested_if(tmpdir) -> None: # language=python content = """ class MyClass: @@ -198,7 +198,7 @@ def foo(self): ) -def test_reduce_condition_else_if_to_true(tmpdir): +def test_reduce_condition_else_if_to_true(tmpdir) -> None: # language=python content = """ def foo(): @@ -234,7 +234,7 @@ def foo(): ) -def test_reduce_condition_else_if_to_false(tmpdir): +def test_reduce_condition_else_if_to_false(tmpdir) -> None: # language=python content = """ def foo(): @@ -268,7 +268,7 @@ def foo(): ) -def test_reduce_condition_second_else_if_to_true(tmpdir): +def test_reduce_condition_second_else_if_to_true(tmpdir) -> None: # language=python content = """ def foo(): @@ -300,7 +300,7 @@ def foo(): ) -def test_reduce_condition_second_else_if_to_false(tmpdir): +def test_reduce_condition_second_else_if_to_false(tmpdir) -> None: # language=python content = """ def foo(): @@ -332,7 +332,7 @@ def foo(): ) -def test_reduce_condition_else_statement_raises(tmpdir): +def test_reduce_condition_else_statement_raises(tmpdir) -> None: # language=python content = """ def foo(): @@ -351,7 +351,7 @@ def foo(): else_block.reduce_condition(True) -def test_reduce_condition_to_true_single_if(tmpdir): +def test_reduce_condition_to_true_single_if(tmpdir) -> None: # language=python content = """ def foo(): @@ -375,7 +375,7 @@ def foo(): ) -def test_reduce_condition_to_false_single_if(tmpdir): +def test_reduce_condition_to_false_single_if(tmpdir) -> None: # language=python content = """ def foo(): diff --git a/tests/unit/codegen/sdk/python/statements/with_statement/test_with_statement_properties.py b/tests/unit/codegen/sdk/python/statements/with_statement/test_with_statement_properties.py index db139e11e..45b129002 100644 --- a/tests/unit/codegen/sdk/python/statements/with_statement/test_with_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/with_statement/test_with_statement_properties.py @@ -8,7 +8,7 @@ from codegen.sdk.python import PyFile -def test_with_statement_single_var(tmpdir): +def test_with_statement_single_var(tmpdir) -> None: # language=python content = """ def foo(): @@ -38,7 +38,7 @@ def foo(): ) -def test_with_statement_aliased_var(tmpdir): +def test_with_statement_aliased_var(tmpdir) -> None: # language=python content = """ def foo(): @@ -67,7 +67,7 @@ def foo(): ) -def test_with_statement_multiple_vars(tmpdir): +def test_with_statement_multiple_vars(tmpdir) -> None: # language=python content = """ def foo(): @@ -96,7 +96,7 @@ def foo(): ) -def test_with_statement_multiline_vars(tmpdir): +def test_with_statement_multiline_vars(tmpdir) -> None: # language=python content = """ def foo(): @@ -137,7 +137,7 @@ def foo(): ) -def test_with_statement_function_call(tmpdir): +def test_with_statement_function_call(tmpdir) -> None: # language=python content = """ def foo(): diff --git a/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_indent.py b/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_indent.py index 35e7622be..e0e63548d 100644 --- a/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_indent.py +++ b/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_indent.py @@ -3,7 +3,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_indent_once_to_left(tmpdir): +def test_indent_once_to_left(tmpdir) -> None: # language=typescript content = """ function foo(): number { @@ -30,7 +30,7 @@ def test_indent_once_to_left(tmpdir): ) -def test_indent_left_exceed_limit(tmpdir): +def test_indent_left_exceed_limit(tmpdir) -> None: # language=typescript content = """ function foo(): number { @@ -57,7 +57,7 @@ def test_indent_left_exceed_limit(tmpdir): ) -def test_indent_right_once(tmpdir): +def test_indent_right_once(tmpdir) -> None: # language=typescript content = """ function foo(): number { @@ -84,7 +84,7 @@ def test_indent_right_once(tmpdir): ) -def test_indent_right_multiple(tmpdir): +def test_indent_right_multiple(tmpdir) -> None: # language=typescript content = """ function foo(): number { @@ -111,7 +111,7 @@ def test_indent_right_multiple(tmpdir): ) -def test_indent_zero(tmpdir): +def test_indent_zero(tmpdir) -> None: # language=typescript content = """ function foo(): number { @@ -128,7 +128,7 @@ def test_indent_zero(tmpdir): assert file.content == content -def test_indent_nested_code_blocks(tmpdir): +def test_indent_nested_code_blocks(tmpdir) -> None: # language=typescript content = """ function foo(): number { @@ -201,7 +201,7 @@ def test_indent_nested_code_blocks(tmpdir): ) -def test_indent_only_nested_block(tmpdir): +def test_indent_only_nested_block(tmpdir) -> None: # language=typescript content = """ function foo(a: boolean, b: any, c: any, d: boolean): void { diff --git a/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_unwrap.py b/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_unwrap.py index 0ebccc5b5..e37d109dd 100644 --- a/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_unwrap.py +++ b/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_unwrap.py @@ -8,7 +8,7 @@ from codegen.sdk.typescript.file import TSFile -def test_unwrap_flat_statements(tmpdir): +def test_unwrap_flat_statements(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -40,7 +40,7 @@ def test_unwrap_flat_statements(tmpdir): ) -def test_unwrap_nested_statements(tmpdir): +def test_unwrap_nested_statements(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -83,7 +83,7 @@ def test_unwrap_nested_statements(tmpdir): ) -def test_unwrap_sandwiched_statements(tmpdir): +def test_unwrap_sandwiched_statements(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -124,7 +124,7 @@ def test_unwrap_sandwiched_statements(tmpdir): ) -def test_unwrap_multiline_wrapper(tmpdir): +def test_unwrap_multiline_wrapper(tmpdir) -> None: # language=typescript content = """ function foo(): void { diff --git a/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_wrap.py b/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_wrap.py index 0a52dece6..2b3c805bf 100644 --- a/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_wrap.py +++ b/tests/unit/codegen/sdk/typescript/detached_symbols/code_block/test_code_block_wrap.py @@ -2,7 +2,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_wrap_with_if_statement(tmpdir): +def test_wrap_with_if_statement(tmpdir) -> None: # language=typescript jsx content = """ function funcA(a) { diff --git a/tests/unit/codegen/sdk/typescript/export/test_export_aliased.py b/tests/unit/codegen/sdk/typescript/export/test_export_aliased.py index 48bc5a502..ed5dc6a07 100644 --- a/tests/unit/codegen/sdk/typescript/export/test_export_aliased.py +++ b/tests/unit/codegen/sdk/typescript/export/test_export_aliased.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_export_aliased(tmpdir): +def test_export_aliased(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -31,7 +31,7 @@ def test_export_aliased(tmpdir): assert fuzz_export.resolved_symbol == file.get_function("fuzz") -def test_export_aliased_multiple(tmpdir): +def test_export_aliased_multiple(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -60,7 +60,7 @@ def test_export_aliased_multiple(tmpdir): assert baz_export.resolved_symbol == file.get_function("baz") -def test_export_aliased_default(tmpdir): +def test_export_aliased_default(tmpdir) -> None: # language=typescript content = """ function foo() {} diff --git a/tests/unit/codegen/sdk/typescript/export/test_export_export_statement.py b/tests/unit/codegen/sdk/typescript/export/test_export_export_statement.py index c9ae8a410..113a591ef 100644 --- a/tests/unit/codegen/sdk/typescript/export/test_export_export_statement.py +++ b/tests/unit/codegen/sdk/typescript/export/test_export_export_statement.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_export_export_statement(tmpdir): +def test_export_export_statement(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -29,7 +29,7 @@ def test_export_export_statement(tmpdir): ] -def test_remove_export(tmpdir): +def test_remove_export(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -55,7 +55,7 @@ def test_remove_export(tmpdir): ) -def test_remove_all_exports(tmpdir): +def test_remove_all_exports(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -77,7 +77,7 @@ def test_remove_all_exports(tmpdir): ) -def test_add_export(tmpdir): +def test_add_export(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -101,7 +101,7 @@ def test_add_export(tmpdir): ) -def test_insert_export_at_start(tmpdir): +def test_insert_export_at_start(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -125,7 +125,7 @@ def test_insert_export_at_start(tmpdir): ) -def test_remove_middle_export(tmpdir): +def test_remove_middle_export(tmpdir) -> None: # language=typescript content = """ function foo() {} @@ -150,7 +150,7 @@ def test_remove_middle_export(tmpdir): ) -def test_remove_inline_export(tmpdir): +def test_remove_inline_export(tmpdir) -> None: # language=typescript content = """ export function foo() {} @@ -172,7 +172,7 @@ def test_remove_inline_export(tmpdir): ) -def test_add_remove_reexport_inline(tmpdir): +def test_add_remove_reexport_inline(tmpdir) -> None: # language=typescript content = """ export { foo, bar } from './other'; @@ -198,7 +198,7 @@ def test_add_remove_reexport_inline(tmpdir): ) -def test_remove_namespace_export(tmpdir): +def test_remove_namespace_export(tmpdir) -> None: # language=typescript content = """ export namespace Foo { @@ -224,7 +224,7 @@ def test_remove_namespace_export(tmpdir): ) -def test_remove_add_wildcard_export(tmpdir): +def test_remove_add_wildcard_export(tmpdir) -> None: # language=typescript content = """ export * from './foo'; @@ -245,7 +245,7 @@ def test_remove_add_wildcard_export(tmpdir): ) -def test_remove_all_exports_multiline(tmpdir): +def test_remove_all_exports_multiline(tmpdir) -> None: # language=typescript content = """ function foo() {} diff --git a/tests/unit/codegen/sdk/typescript/export/test_external_export.py b/tests/unit/codegen/sdk/typescript/export/test_external_export.py index 56c169aeb..84e74d0ac 100644 --- a/tests/unit/codegen/sdk/typescript/export/test_external_export.py +++ b/tests/unit/codegen/sdk/typescript/export/test_external_export.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_is_external_export_true(tmpdir): +def test_is_external_export_true(tmpdir) -> None: # language=typescript content = """ export { default as React } from "react"; @@ -22,7 +22,7 @@ def test_is_external_export_true(tmpdir): assert file.exports[0].is_external_export is True -def test_is_external_export_false(tmpdir): +def test_is_external_export_false(tmpdir) -> None: # language=typescript content = """ export { foo } from "./foo"; @@ -40,7 +40,7 @@ def test_is_external_export_false(tmpdir): assert file.exports[0].is_external_export is False -def test_multiple_external_exports(tmpdir): +def test_multiple_external_exports(tmpdir) -> None: # language=typescript content = """ export { default as React } from "react"; @@ -57,7 +57,7 @@ def test_multiple_external_exports(tmpdir): assert all(export.is_external_export for export in file.exports) -def test_mixed_internal_external_exports(tmpdir): +def test_mixed_internal_external_exports(tmpdir) -> None: # language=typescript content = """ export { default as lodash } from "lodash"; @@ -80,7 +80,7 @@ def test_mixed_internal_external_exports(tmpdir): assert file.exports[2].is_external_export is False -def test_nested_reexports(tmpdir): +def test_nested_reexports(tmpdir) -> None: # language=typescript with get_codebase_session( tmpdir=tmpdir, @@ -106,7 +106,7 @@ def test_nested_reexports(tmpdir): assert services_file.exports[0].is_external_export is False -def test_wildcard_exports(tmpdir): +def test_wildcard_exports(tmpdir) -> None: # language=typescript with get_codebase_session( tmpdir=tmpdir, diff --git a/tests/unit/codegen/sdk/typescript/export/test_get_reexported_exports.py b/tests/unit/codegen/sdk/typescript/export/test_get_reexported_exports.py index e035e2030..ecd60e762 100644 --- a/tests/unit/codegen/sdk/typescript/export/test_get_reexported_exports.py +++ b/tests/unit/codegen/sdk/typescript/export/test_get_reexported_exports.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_get_reexported_exports(tmpdir): +def test_get_reexported_exports(tmpdir) -> None: # language=typescript with get_codebase_session( tmpdir=tmpdir, @@ -42,7 +42,7 @@ def test_get_reexported_exports(tmpdir): assert len(reexports) == 0 -def test_get_reexported_export_with_external_module_on_export(tmpdir): +def test_get_reexported_export_with_external_module_on_export(tmpdir) -> None: # language=typescript with get_codebase_session( tmpdir=tmpdir, @@ -70,7 +70,7 @@ def test_get_reexported_export_with_external_module_on_export(tmpdir): assert len(reexports) == 2 -def test_get_reexported_export_with_alias(tmpdir): +def test_get_reexported_export_with_alias(tmpdir) -> None: # language=typescript with get_codebase_session( tmpdir=tmpdir, diff --git a/tests/unit/codegen/sdk/typescript/expressions/ternary_expression/test_ternary_reduce_condition.py b/tests/unit/codegen/sdk/typescript/expressions/ternary_expression/test_ternary_reduce_condition.py index fac383111..608ba5f88 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/ternary_expression/test_ternary_reduce_condition.py +++ b/tests/unit/codegen/sdk/typescript/expressions/ternary_expression/test_ternary_reduce_condition.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_reduce_ternary_condition_to_true(tmpdir): +def test_reduce_ternary_condition_to_true(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -32,7 +32,7 @@ def test_reduce_ternary_condition_to_true(tmpdir): ) -def test_reduce_ternary_condition_to_false(tmpdir): +def test_reduce_ternary_condition_to_false(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -57,7 +57,7 @@ def test_reduce_ternary_condition_to_false(tmpdir): ) -def test_reduce_nested_ternary_condition_to_true_and_false(tmpdir): +def test_reduce_nested_ternary_condition_to_true_and_false(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -84,7 +84,7 @@ def test_reduce_nested_ternary_condition_to_true_and_false(tmpdir): ) -def test_reduce_nested_ternary_condition_outer_false(tmpdir): +def test_reduce_nested_ternary_condition_outer_false(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -109,7 +109,7 @@ def test_reduce_nested_ternary_condition_outer_false(tmpdir): ) -def test_reduce_multiple_ternary_conditions(tmpdir): +def test_reduce_multiple_ternary_conditions(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -138,7 +138,7 @@ def test_reduce_multiple_ternary_conditions(tmpdir): ) -def test_reduce_ternary_condition_with_function_call(tmpdir): +def test_reduce_ternary_condition_with_function_call(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -163,7 +163,7 @@ def test_reduce_ternary_condition_with_function_call(tmpdir): ) -def test_reduce_ternary_condition_with_dict(tmpdir): +def test_reduce_ternary_condition_with_dict(tmpdir) -> None: # language=typescript jsx # language=typescript content = """ @@ -198,7 +198,7 @@ def test_reduce_ternary_condition_with_dict(tmpdir): ) -def test_reduce_ternary_condition_with_dict_complex(tmpdir): +def test_reduce_ternary_condition_with_dict_complex(tmpdir) -> None: # language=typescript jsx # language=typescript content = """ @@ -234,7 +234,7 @@ def test_reduce_ternary_condition_with_dict_complex(tmpdir): ) -def test_reduce_ternary_condition_with_dict_trailing_comma(tmpdir): +def test_reduce_ternary_condition_with_dict_trailing_comma(tmpdir) -> None: # language=typescript content = """ function foo(): { a: number } { diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_binary_expression.py b/tests/unit/codegen/sdk/typescript/expressions/test_binary_expression.py index 9d4f26f75..4d862abfa 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_binary_expression.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_binary_expression.py @@ -8,7 +8,7 @@ from codegen.sdk.typescript.file import TSFile -def test_binary_expressions(tmpdir): +def test_binary_expressions(tmpdir) -> None: # language=typescript content = """ let a: number = 1 + 2; @@ -127,7 +127,7 @@ def test_binary_expressions(tmpdir): assert file.get_global_var("v").value.operator.source == "in" -def test_chained_binary_expressions(tmpdir): +def test_chained_binary_expressions(tmpdir) -> None: # language=typescript content = """ // binary operators diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_binary_reduce_condition.py b/tests/unit/codegen/sdk/typescript/expressions/test_binary_reduce_condition.py index b21ef842b..a99a21bd7 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_binary_reduce_condition.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_binary_reduce_condition.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_reduce_binary_simple(tmpdir): +def test_reduce_binary_simple(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -33,7 +33,7 @@ def test_reduce_binary_simple(tmpdir): ) -def test_reduce_binary_object(tmpdir): +def test_reduce_binary_object(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -62,7 +62,7 @@ def test_reduce_binary_object(tmpdir): ) -def test_reduce_binary_complex_condition(tmpdir): +def test_reduce_binary_complex_condition(tmpdir) -> None: # language=typescript content = """ function bar(): void { @@ -91,7 +91,7 @@ def test_reduce_binary_complex_condition(tmpdir): ) -def test_reduce_negation_condition(tmpdir): +def test_reduce_negation_condition(tmpdir) -> None: # language=typescript content = """ function baz(): void { @@ -119,7 +119,7 @@ def test_reduce_negation_condition(tmpdir): ) -def test_reduce_jsx_element(tmpdir): +def test_reduce_jsx_element(tmpdir) -> None: # language=typescript jsx content = """ const MyComponent: React.FC = () => { @@ -154,7 +154,7 @@ def test_reduce_jsx_element(tmpdir): ) -def test_reduce_jsx_element_keep(tmpdir): +def test_reduce_jsx_element_keep(tmpdir) -> None: # language=typescript jsx content = """ const MyComponent: React.FC = () => { @@ -188,7 +188,7 @@ def test_reduce_jsx_element_keep(tmpdir): ) -def test_reduce_complex_condition_with_JSX(tmpdir): +def test_reduce_complex_condition_with_JSX(tmpdir) -> None: # language=typescript jsx content = """ const AnotherComponent: React.FC = () => { diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_boolean.py b/tests/unit/codegen/sdk/typescript/expressions/test_boolean.py index 7ce20bc67..5e2b5e668 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_boolean.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_boolean.py @@ -5,7 +5,7 @@ from codegen.sdk.typescript.assignment import TSAssignment -def test_boolean_parse(tmpdir): +def test_boolean_parse(tmpdir) -> None: # language=typescript content = """ const a = true; diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_dict.py b/tests/unit/codegen/sdk/typescript/expressions/test_dict.py index 4478340dc..b9143da32 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_dict.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_dict.py @@ -9,7 +9,7 @@ from codegen.sdk.typescript.function import TSFunction -def test_dict_basic(tmpdir): +def test_dict_basic(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -36,7 +36,7 @@ def test_dict_basic(tmpdir): ) -def test_dict_multiline(tmpdir): +def test_dict_multiline(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -73,7 +73,7 @@ def test_dict_multiline(tmpdir): ) -def test_dict_insert(tmpdir): +def test_dict_insert(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -98,7 +98,7 @@ def test_dict_insert(tmpdir): @pytest.mark.parametrize("removes,inserts", cases, ids=[f"{removes=}-{inserts=}" for removes, inserts in cases]) -def test_dict_interleaved(tmpdir, removes, inserts): +def test_dict_interleaved(tmpdir, removes, inserts) -> None: ref_dict = {str(-1 + -i): -(i**2) for i in range(removes)} file = "test.ts" content = f""" @@ -124,7 +124,7 @@ def test_dict_interleaved(tmpdir, removes, inserts): @pytest.mark.parametrize("removes,inserts", cases, ids=[f"{removes=}-{inserts=}" for removes, inserts in cases]) -def test_dict_removes_first(tmpdir, removes, inserts): +def test_dict_removes_first(tmpdir, removes, inserts) -> None: ref_dict = {str(-1 + -i): -(i**2) for i in range(removes)} file = "test.ts" content = f""" @@ -149,7 +149,7 @@ def test_dict_removes_first(tmpdir, removes, inserts): @pytest.mark.parametrize("removes,inserts", cases, ids=[f"{removes=}-{inserts=}" for removes, inserts in cases]) -def test_dict_inserts_first(tmpdir, removes, inserts): +def test_dict_inserts_first(tmpdir, removes, inserts) -> None: ref_dict = {str(-1 + -i): -(i**2) for i in range(removes)} file = "test.ts" content = f""" @@ -174,7 +174,7 @@ def test_dict_inserts_first(tmpdir, removes, inserts): @pytest.mark.parametrize("existing,inserts", cases, ids=[f"{existing=}-{inserts=}" for existing, inserts in cases]) -def test_dict_append_existing(tmpdir, existing, inserts): +def test_dict_append_existing(tmpdir, existing, inserts) -> None: ref_dict = {str(-1 + -i): -(i**2) for i in range(existing)} file = "test.ts" content = f""" @@ -196,7 +196,7 @@ def test_dict_append_existing(tmpdir, existing, inserts): @pytest.mark.parametrize("existing", list(range(4)), ids=[f"existing={existing}" for existing in range(4)]) -def test_dict_set_existing(tmpdir, existing): +def test_dict_set_existing(tmpdir, existing) -> None: ref_dict = {str(-1 + -i): -(i**2) for i in range(existing)} file = "test.ts" content = f""" @@ -218,7 +218,7 @@ def test_dict_set_existing(tmpdir, existing): @pytest.mark.parametrize("existing,inserts", cases, ids=[f"existing={existing + 1}-{inserts=}" for existing, inserts in cases]) -def test_dict_set_existing_same(tmpdir, existing, inserts): +def test_dict_set_existing_same(tmpdir, existing, inserts) -> None: ref_dict = {str(-1 + -i): -(i**2) for i in range(existing)} file = "test.ts" content = f""" @@ -239,7 +239,7 @@ def test_dict_set_existing_same(tmpdir, existing, inserts): ) -def test_dict_empty(tmpdir): +def test_dict_empty(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -263,7 +263,7 @@ def test_dict_empty(tmpdir): ) -def test_dict_remove_insert(tmpdir): +def test_dict_remove_insert(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -286,7 +286,7 @@ def test_dict_remove_insert(tmpdir): ) -def test_dict_shorthand_remove_insert(tmpdir): +def test_dict_shorthand_remove_insert(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -309,7 +309,7 @@ def test_dict_shorthand_remove_insert(tmpdir): ) -def test_dict_edit(tmpdir): +def test_dict_edit(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -330,7 +330,7 @@ def test_dict_edit(tmpdir): ) -def test_dict_edit_shorthand(tmpdir): +def test_dict_edit_shorthand(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -351,7 +351,7 @@ def test_dict_edit_shorthand(tmpdir): ) -def test_dict_edit_not_shorthand(tmpdir): +def test_dict_edit_not_shorthand(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -372,7 +372,7 @@ def test_dict_edit_not_shorthand(tmpdir): ) -def test_dict_clear(tmpdir): +def test_dict_clear(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -394,7 +394,7 @@ def test_dict_clear(tmpdir): @pytest.mark.xfail(reason="Not implemented Yet") -def test_dict_obj(tmpdir): +def test_dict_obj(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -421,7 +421,7 @@ def test_dict_obj(tmpdir): ) -def test_dict_shorthand(tmpdir): +def test_dict_shorthand(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -448,7 +448,7 @@ def test_dict_shorthand(tmpdir): ) -def test_shorthand_multiline(tmpdir): +def test_shorthand_multiline(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -485,7 +485,7 @@ def test_shorthand_multiline(tmpdir): ) -def test_convert_shorthand(tmpdir): +def test_convert_shorthand(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -509,7 +509,7 @@ def test_convert_shorthand(tmpdir): ) -def test_dict_shorthand_insert(tmpdir): +def test_dict_shorthand_insert(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -530,7 +530,7 @@ def test_dict_shorthand_insert(tmpdir): ) -def test_dict_function_values(tmpdir): +def test_dict_function_values(tmpdir) -> None: # language=typescript content = """ export const mapper = { @@ -559,7 +559,7 @@ def test_dict_function_values(tmpdir): assert len(list(d.items())) == 3 assert list(d.keys()) == ["method1", "method2", "method3"] assert all(isinstance(v, TSFunction) for v in list(d.values())) - for key, func_def in d.items(): + for func_def in d.values(): func_def.insert_before("async ", newline=False, extended=False) func_def.set_return_type(f"Promise<{func_def.return_type.source}>") @@ -588,7 +588,7 @@ def test_dict_function_values(tmpdir): ) -def test_dict_usage(tmpdir): +def test_dict_usage(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -604,7 +604,7 @@ def test_dict_usage(tmpdir): assert {*foo.symbol_usages} == {obj} -def test_dict_usage_shorthand(tmpdir): +def test_dict_usage_shorthand(tmpdir) -> None: file = "test.ts" # language=typescript content = """ @@ -620,7 +620,7 @@ def test_dict_usage_shorthand(tmpdir): assert {*foo.symbol_usages} == {obj} -def test_dict_usage_spread(tmpdir): +def test_dict_usage_spread(tmpdir) -> None: file = "test.ts" # language=typescript jsx # language=typescript diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_expression.py b/tests/unit/codegen/sdk/typescript/expressions/test_expression.py index 3e66e8711..890a6c850 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_expression.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_expression.py @@ -2,7 +2,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_function_calls_from_function_call(tmpdir): +def test_function_calls_from_function_call(tmpdir) -> None: # language=typescript content = """ describe("LoadManager test", () => { @@ -19,7 +19,7 @@ def test_function_calls_from_function_call(tmpdir): assert len(func_call) == 6 -def test_remove_parenthesized(tmpdir): +def test_remove_parenthesized(tmpdir) -> None: # language=typescript content = """ const a = (b) || c; @@ -37,7 +37,7 @@ def test_remove_parenthesized(tmpdir): ) -def test_remove_tsx(tmpdir): +def test_remove_tsx(tmpdir) -> None: # language=typescript jsx content = """ const element =

Hello, {name}

; diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_number.py b/tests/unit/codegen/sdk/typescript/expressions/test_number.py index a191ff95b..82820eb99 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_number.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_number.py @@ -4,7 +4,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_number_parse(tmpdir): +def test_number_parse(tmpdir) -> None: # language=typescript content = """ const a = 1; diff --git a/tests/unit/codegen/sdk/typescript/file/test_file_export_statements.py b/tests/unit/codegen/sdk/typescript/file/test_file_export_statements.py index adaa00922..a8117fcbb 100644 --- a/tests/unit/codegen/sdk/typescript/file/test_file_export_statements.py +++ b/tests/unit/codegen/sdk/typescript/file/test_file_export_statements.py @@ -20,7 +20,7 @@ def _get_file(content: str): @pytest.fixture def export_check(ts_file): - def _export_check(*sources): + def _export_check(*sources) -> None: file = ts_file("\n".join(sources)) assert len(file.export_statements) == len(sources) for i, source in enumerate(sources): @@ -32,137 +32,137 @@ def _export_check(*sources): # 1. Named exports -def test_named_exports__export_const(export_check): +def test_named_exports__export_const(export_check) -> None: export_check("export const variable = value;") -def test_named_exports__export_let(export_check): +def test_named_exports__export_let(export_check) -> None: export_check("export let variable = value;") -def test_named_exports__export_var(export_check): +def test_named_exports__export_var(export_check) -> None: export_check("export var variable = value;") -def test_named_exports__export_function(export_check): +def test_named_exports__export_function(export_check) -> None: export_check("export function functionName() { /* ... */ }") -def test_named_exports__export_class(export_check): +def test_named_exports__export_class(export_check) -> None: export_check("export class ClassName { /* ... */ }") -def test_named_exports__export_interface(export_check): +def test_named_exports__export_interface(export_check) -> None: export_check("export interface InterfaceName { /* ... */ }") -def test_named_exports__export_type(export_check): +def test_named_exports__export_type(export_check) -> None: export_check("export type TypeName = value;") -def test_named_exports__export_enum(export_check): +def test_named_exports__export_enum(export_check) -> None: export_check("export enum EnumName { /* ... */ }") # 2. Default exports -def test_default_exports__export_default_value(export_check): +def test_default_exports__export_default_value(export_check) -> None: export_check("export default value;") -def test_default_exports__export_default_function(export_check): +def test_default_exports__export_default_function(export_check) -> None: export_check("export default function() { /* ... */ }") -def test_default_exports__export_default_class(export_check): +def test_default_exports__export_default_class(export_check) -> None: export_check("export default class { /* ... */ }") # 3. Exporting declarations after they're defined -def test_defined_declarations__just_export_statement(export_check): +def test_defined_declarations__just_export_statement(export_check) -> None: export_check("export { variable, functionName, ClassName };") # 4. Exporting with aliases -def test_exporting_with_aliases(export_check): +def test_exporting_with_aliases(export_check) -> None: export_check("export { variable as aliasName };") # 5. Re-exporting -def test_reexporting__export_many_names(export_check): +def test_reexporting__export_many_names(export_check) -> None: export_check("export { name1, name2 } from './other-module';") -def test_reexporting__export_many_names_with_aliases(export_check): +def test_reexporting__export_many_names_with_aliases(export_check) -> None: export_check("export { import1 as name1, import2 as name2 } from './other-module';") -def test_reexporting__export_star(export_check): +def test_reexporting__export_star(export_check) -> None: export_check("export * from './other-module';") -def test_reexporting__export_default(export_check): +def test_reexporting__export_default(export_check) -> None: export_check("export { default } from './other-module';") # 6. Combining default and named exports -def test_default_named__export_default_class(export_check): +def test_default_named__export_default_class(export_check) -> None: export_check("export default class { /* ... */ }") -def test_default_named__export_const(export_check): +def test_default_named__export_const(export_check) -> None: export_check("export const namedExport = value;") # 7. Exporting types -def test_types__export_type_block(export_check): +def test_types__export_type_block(export_check) -> None: export_check("export type { SomeType } from './types';") # 8. Exporting namespaces -def test_export_namespace(export_check): +def test_export_namespace(export_check) -> None: export_check("export namespace MyNamespace { export const something = value; }") # 9. Exporting as a namespace -def test_export_star_as_namespace(export_check): +def test_export_star_as_namespace(export_check) -> None: export_check("export * as myNamespace from './module';") # 10. Exporting declarations with modifiers -def test_declaration_with_modifiers__export_abstract_class(export_check): +def test_declaration_with_modifiers__export_abstract_class(export_check) -> None: export_check("export abstract class AbstractClass { /* ... */ }") -def test_declaration_with_modifiers__export_const_enum(export_check): +def test_declaration_with_modifiers__export_const_enum(export_check) -> None: export_check("export const enum ConstEnum { /* ... */ }") # 11. Exporting with generics -def test_generics__export_function(export_check): +def test_generics__export_function(export_check) -> None: export_check("export function genericFunction() { /* ... */ }") -def test_generics__export_class(export_check): +def test_generics__export_class(export_check) -> None: export_check("export class GenericClass { /* ... */ }") diff --git a/tests/unit/codegen/sdk/typescript/function/test_function_async.py b/tests/unit/codegen/sdk/typescript/function/test_function_async.py index 8383b911b..9ac41feaa 100644 --- a/tests/unit/codegen/sdk/typescript/function/test_function_async.py +++ b/tests/unit/codegen/sdk/typescript/function/test_function_async.py @@ -2,7 +2,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_function_is_async_basic(tmpdir): +def test_function_is_async_basic(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -66,7 +66,7 @@ class MyClass { ) -def test_function_is_async_extended(tmpdir): +def test_function_is_async_extended(tmpdir) -> None: # language=typescript content = """ /** Docstring */ @@ -146,7 +146,7 @@ class MyClass { ) -def test_function_is_async_other_syntax(tmpdir): +def test_function_is_async_other_syntax(tmpdir) -> None: # language=typescript content = """ // Arrow functions diff --git a/tests/unit/codegen/sdk/typescript/function/test_function_calls_from_symbol.py b/tests/unit/codegen/sdk/typescript/function/test_function_calls_from_symbol.py index 871518f3a..2e00b5687 100644 --- a/tests/unit/codegen/sdk/typescript/function/test_function_calls_from_symbol.py +++ b/tests/unit/codegen/sdk/typescript/function/test_function_calls_from_symbol.py @@ -13,7 +13,7 @@ from codegen.sdk.typescript.statements.assignment_statement import TSAssignmentStatement -def test_function_calls_from_file(tmpdir): +def test_function_calls_from_file(tmpdir) -> None: # language=typescript content = """ import { x, y, z } from './some_file'; @@ -59,7 +59,7 @@ def test_function_calls_from_file(tmpdir): ] -def test_function_calls_from_class(tmpdir): +def test_function_calls_from_class(tmpdir) -> None: # language=typescript content = """ import { bar } from './some_file'; @@ -100,7 +100,7 @@ class A { @pytest.mark.skip(reason="CG-9422 TS method decorator parsing needs to be fixed first") -def test_function_calls_from_decorated_definitions(tmpdir): +def test_function_calls_from_decorated_definitions(tmpdir) -> None: # language=typescript content = """ import { describe, it } from 'jest'; @@ -139,7 +139,7 @@ class A { @pytest.mark.xfail(reason="Broken by function call changes") -def test_function_calls_from_datatypes(tmpdir): +def test_function_calls_from_datatypes(tmpdir) -> None: # language=typescript content = """ function getConfig(): { max_retries: number; timeout: number } { @@ -178,7 +178,7 @@ def test_function_calls_from_datatypes(tmpdir): ] -def test_function_calls_from_function_parameters(tmpdir): +def test_function_calls_from_function_parameters(tmpdir) -> None: # language=typescript content = """ // function parameters @@ -200,7 +200,7 @@ def test_function_calls_from_function_parameters(tmpdir): ] -def test_function_calls_from_while_loop(tmpdir): +def test_function_calls_from_while_loop(tmpdir) -> None: # language=typescript content = """ // while loop conditions @@ -222,7 +222,7 @@ def test_function_calls_from_while_loop(tmpdir): assert file.function_calls[1].parent.statement_type == StatementType.EXPRESSION_STATEMENT -def test_function_calls_from_if_conditions(tmpdir): +def test_function_calls_from_if_conditions(tmpdir) -> None: # language=typescript content = """ // if conditions @@ -247,7 +247,7 @@ def test_function_calls_from_if_conditions(tmpdir): ] -def test_function_calls_for_nested_calls(tmpdir): +def test_function_calls_for_nested_calls(tmpdir) -> None: # language=typescript content = """ parent(nested()) @@ -265,7 +265,7 @@ def test_function_calls_for_nested_calls(tmpdir): ] -def test_function_calls_for_chained_calls(tmpdir): +def test_function_calls_for_chained_calls(tmpdir) -> None: # language=typescript content = """ parent().child().grandchild() @@ -287,7 +287,7 @@ def test_function_calls_for_chained_calls(tmpdir): assert fcalls[2].parent.parent == fcalls[1] -def test_function_calls_in_function_call(tmpdir): +def test_function_calls_in_function_call(tmpdir) -> None: # language=typescript content = """ describe("top level", () => { diff --git a/tests/unit/codegen/sdk/typescript/function/test_function_parameter_to_interface.py b/tests/unit/codegen/sdk/typescript/function/test_function_parameter_to_interface.py index 13b2100bf..8ed0e838f 100644 --- a/tests/unit/codegen/sdk/typescript/function/test_function_parameter_to_interface.py +++ b/tests/unit/codegen/sdk/typescript/function/test_function_parameter_to_interface.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_convert_simple_props_to_interface(tmpdir): +def test_convert_simple_props_to_interface(tmpdir) -> None: # language=typescript content = """ function Button(props: { @@ -40,7 +40,7 @@ def test_convert_simple_props_to_interface(tmpdir): ) -def test_convert_props_to_interface_with_complex_types(tmpdir): +def test_convert_props_to_interface_with_complex_types(tmpdir) -> None: # language=typescript content = """ function DataGrid(props: { @@ -75,7 +75,7 @@ def test_convert_props_to_interface_with_complex_types(tmpdir): ) -def test_convert_props_to_interface_with_generics(tmpdir): +def test_convert_props_to_interface_with_generics(tmpdir) -> None: # language=typescript content = """ function List(props: { @@ -111,7 +111,7 @@ def test_convert_props_to_interface_with_generics(tmpdir): ) -def test_convert_props_to_interface_with_extends(tmpdir): +def test_convert_props_to_interface_with_extends(tmpdir) -> None: # language=typescript content = """ function CustomButton(props: { diff --git a/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py b/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py index 33527dab5..773accab5 100644 --- a/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py +++ b/tests/unit/codegen/sdk/typescript/import_resolution/test_is_dynamic.py @@ -2,7 +2,7 @@ from codegen.sdk.enums import ProgrammingLanguage -def test_ts_import_is_dynamic_in_function_declaration(tmpdir): +def test_ts_import_is_dynamic_in_function_declaration(tmpdir) -> None: # language=typescript content = """ import { staticImport } from './static'; @@ -21,7 +21,7 @@ def test_ts_import_is_dynamic_in_function_declaration(tmpdir): assert imports[1].is_dynamic # dynamic import in function -def test_ts_import_is_dynamic_in_method_definition(tmpdir): +def test_ts_import_is_dynamic_in_method_definition(tmpdir) -> None: # language=typescript content = """ import { Component } from '@angular/core'; @@ -46,7 +46,7 @@ class MyComponent { assert imports[2].is_dynamic # dynamic import in decorated method -def test_ts_import_is_dynamic_in_arrow_function(tmpdir): +def test_ts_import_is_dynamic_in_arrow_function(tmpdir) -> None: # language=typescript content = """ import { useState } from 'react'; @@ -67,7 +67,7 @@ def test_ts_import_is_dynamic_in_arrow_function(tmpdir): assert imports[1].is_dynamic # dynamic import in async arrow function -def test_ts_import_is_dynamic_in_if_statement(tmpdir): +def test_ts_import_is_dynamic_in_if_statement(tmpdir) -> None: # language=typescript content = """ import { isFeatureEnabled } from './utils'; @@ -91,7 +91,7 @@ def test_ts_import_is_dynamic_in_if_statement(tmpdir): assert imports[2].is_dynamic # dynamic import in else block -def test_ts_import_is_dynamic_in_try_statement(tmpdir): +def test_ts_import_is_dynamic_in_try_statement(tmpdir) -> None: # language=typescript content = """ import { logger } from './logger'; @@ -112,7 +112,7 @@ def test_ts_import_is_dynamic_in_try_statement(tmpdir): assert imports[1].is_dynamic # dynamic import in try block -def test_ts_import_is_dynamic_in_catch_clause(tmpdir): +def test_ts_import_is_dynamic_in_catch_clause(tmpdir) -> None: # language=typescript content = """ import { logger } from './logger'; @@ -132,7 +132,7 @@ def test_ts_import_is_dynamic_in_catch_clause(tmpdir): assert imports[1].is_dynamic # dynamic import in catch block -def test_ts_import_is_dynamic_in_finally_clause(tmpdir): +def test_ts_import_is_dynamic_in_finally_clause(tmpdir) -> None: # language=typescript content = """ import { logger } from './logger'; @@ -154,7 +154,7 @@ def test_ts_import_is_dynamic_in_finally_clause(tmpdir): assert imports[1].is_dynamic # dynamic import in finally block -def test_ts_import_is_dynamic_in_while_statement(tmpdir): +def test_ts_import_is_dynamic_in_while_statement(tmpdir) -> None: # language=typescript content = """ import { condition } from './utils'; @@ -172,7 +172,7 @@ def test_ts_import_is_dynamic_in_while_statement(tmpdir): assert imports[1].is_dynamic # dynamic import in while loop -def test_ts_import_is_dynamic_in_for_statement(tmpdir): +def test_ts_import_is_dynamic_in_for_statement(tmpdir) -> None: # language=typescript content = """ import { items } from './data'; @@ -190,7 +190,7 @@ def test_ts_import_is_dynamic_in_for_statement(tmpdir): assert imports[1].is_dynamic # dynamic import in for loop -def test_ts_import_is_dynamic_in_do_statement(tmpdir): +def test_ts_import_is_dynamic_in_do_statement(tmpdir) -> None: # language=typescript content = """ import { shouldContinue } from './utils'; @@ -208,7 +208,7 @@ def test_ts_import_is_dynamic_in_do_statement(tmpdir): assert imports[1].is_dynamic # dynamic import in do-while loop -def test_ts_import_is_dynamic_in_switch_statement(tmpdir): +def test_ts_import_is_dynamic_in_switch_statement(tmpdir) -> None: # language=typescript content = """ import { getFeatureFlag } from './utils'; diff --git a/tests/unit/codegen/sdk/typescript/react/test_interface_props.py b/tests/unit/codegen/sdk/typescript/react/test_interface_props.py index 8afaa33c3..8011f97f9 100644 --- a/tests/unit/codegen/sdk/typescript/react/test_interface_props.py +++ b/tests/unit/codegen/sdk/typescript/react/test_interface_props.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_convert_props_to_interface_with_proptypes_oneof(tmpdir): +def test_convert_props_to_interface_with_proptypes_oneof(tmpdir) -> None: # language=typescript content = """ import { PropTypes } from 'react'; @@ -53,7 +53,7 @@ def test_convert_props_to_interface_with_proptypes_oneof(tmpdir): ) -def test_convert_class_props_to_interface_with_proptypes_oneof(tmpdir): +def test_convert_class_props_to_interface_with_proptypes_oneof(tmpdir) -> None: # language=typescript content = """ import { PropTypes } from 'react'; @@ -103,7 +103,7 @@ class Component extends React.Component { ) -def test_convert_simple_function_props_to_interface(tmpdir): +def test_convert_simple_function_props_to_interface(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -140,7 +140,7 @@ def test_convert_simple_function_props_to_interface(tmpdir): ) -def test_convert_complex_function_props_to_interface(tmpdir): +def test_convert_complex_function_props_to_interface(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -210,7 +210,7 @@ def test_convert_complex_function_props_to_interface(tmpdir): ) -def test_convert_simple_class_props_to_interface(tmpdir): +def test_convert_simple_class_props_to_interface(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -255,7 +255,7 @@ class Component extends React.Component { ) -def test_convert_function_props_with_inferred_parameters_to_interface(tmpdir): +def test_convert_function_props_with_inferred_parameters_to_interface(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -316,7 +316,7 @@ def test_convert_function_props_with_inferred_parameters_to_interface(tmpdir): ) -def test_convert_class_props_with_inferred_parameters_to_interface(tmpdir): +def test_convert_class_props_with_inferred_parameters_to_interface(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -383,7 +383,7 @@ class Component extends React.Component { ) -def test_convert_props_to_interface_with_any_params(tmpdir): +def test_convert_props_to_interface_with_any_params(tmpdir) -> None: # language=typescript content = """ import { PropTypes } from 'react'; @@ -420,7 +420,7 @@ def test_convert_props_to_interface_with_any_params(tmpdir): ) -def test_convert_class_props_to_interface_with_any_params(tmpdir): +def test_convert_class_props_to_interface_with_any_params(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -465,7 +465,7 @@ class Component extends React.Component { ) -def test_convert_complex_class_props_to_interface(tmpdir): +def test_convert_complex_class_props_to_interface(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from "react"; @@ -549,7 +549,7 @@ class Component extends React.Component { ) -def test_convert_props_to_interface_with_nested_structures(tmpdir): +def test_convert_props_to_interface_with_nested_structures(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -636,7 +636,7 @@ def test_convert_props_to_interface_with_nested_structures(tmpdir): ) -def test_convert_class_props_to_interface_with_nested_structures(tmpdir): +def test_convert_class_props_to_interface_with_nested_structures(tmpdir) -> None: # language=typescript content = """ import React, { PropTypes } from 'react'; @@ -737,7 +737,7 @@ class OtherComponent extends React.Component { ) -def test_convert_props_to_interface_with_proptypes_node_and_element(tmpdir): +def test_convert_props_to_interface_with_proptypes_node_and_element(tmpdir) -> None: # language=typescript content = """ import { PropTypes } from 'react'; diff --git a/tests/unit/codegen/sdk/typescript/react/test_ts_declassify.py b/tests/unit/codegen/sdk/typescript/react/test_ts_declassify.py index dc4f975a1..b5284739f 100644 --- a/tests/unit/codegen/sdk/typescript/react/test_ts_declassify.py +++ b/tests/unit/codegen/sdk/typescript/react/test_ts_declassify.py @@ -7,7 +7,7 @@ from codegen.sdk.typescript.file import TSFile -def test_declassify_class_component(tmpdir): +def test_declassify_class_component(tmpdir) -> None: # language=typescript content = """ import React from "react"; diff --git a/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_reduce_block.py b/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_reduce_block.py index 8cc9b1ee7..102cd8055 100644 --- a/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_reduce_block.py +++ b/tests/unit/codegen/sdk/typescript/statements/if_block_statement/test_if_block_reduce_block.py @@ -10,7 +10,7 @@ from codegen.sdk.typescript.file import TSFile -def test_reduce_condition_to_true_elif(tmpdir): +def test_reduce_condition_to_true_elif(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -46,7 +46,7 @@ def test_reduce_condition_to_true_elif(tmpdir): ) -def test_reduce_condition_to_false_elif(tmpdir): +def test_reduce_condition_to_false_elif(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -80,7 +80,7 @@ def test_reduce_condition_to_false_elif(tmpdir): ) -def test_reduce_condition_to_true_else(tmpdir): +def test_reduce_condition_to_true_else(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -109,7 +109,7 @@ def test_reduce_condition_to_true_else(tmpdir): ) -def test_reduce_condition_to_false_else(tmpdir): +def test_reduce_condition_to_false_else(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -138,7 +138,7 @@ def test_reduce_condition_to_false_else(tmpdir): ) -def test_reduce_condition_multiple_if_blocks(tmpdir): +def test_reduce_condition_multiple_if_blocks(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -176,7 +176,7 @@ def test_reduce_condition_multiple_if_blocks(tmpdir): ) -def test_reduce_condition_nested_if(tmpdir): +def test_reduce_condition_nested_if(tmpdir) -> None: # language=typescript content = """ class MyClass { @@ -228,7 +228,7 @@ class MyClass { ) -def test_reduce_condition_else_if_to_true(tmpdir): +def test_reduce_condition_else_if_to_true(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -269,7 +269,7 @@ def test_reduce_condition_else_if_to_true(tmpdir): ) -def test_reduce_condition_else_if_to_false(tmpdir): +def test_reduce_condition_else_if_to_false(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -307,7 +307,7 @@ def test_reduce_condition_else_if_to_false(tmpdir): ) -def test_reduce_condition_else_statement_raises(tmpdir): +def test_reduce_condition_else_statement_raises(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -328,7 +328,7 @@ def test_reduce_condition_else_statement_raises(tmpdir): else_block.reduce_condition(True) -def test_reduce_condition_to_true_single_if(tmpdir): +def test_reduce_condition_to_true_single_if(tmpdir) -> None: # language=typescript content = """ function foo(): void { @@ -355,7 +355,7 @@ def test_reduce_condition_to_true_single_if(tmpdir): ) -def test_reduce_condition_to_false_single_if(tmpdir): +def test_reduce_condition_to_false_single_if(tmpdir) -> None: # language=typescript content = """ function foo(): void { diff --git a/tests/unit/codegen/shared/compilation/test_codeblock_validation.py b/tests/unit/codegen/shared/compilation/test_codeblock_validation.py index 0f69ad463..af0e8593e 100644 --- a/tests/unit/codegen/shared/compilation/test_codeblock_validation.py +++ b/tests/unit/codegen/shared/compilation/test_codeblock_validation.py @@ -4,7 +4,7 @@ from codegen.shared.exceptions.compilation import DangerousUserCodeException -def test_no_dangerous_operations(): +def test_no_dangerous_operations() -> None: codeblock = """ print("not dangerous") """ @@ -14,7 +14,7 @@ def test_no_dangerous_operations(): pytest.fail("Unexpected DangerousPythonCodeError raised") -def test_dangerous_operations(): +def test_dangerous_operations() -> None: codeblock = """ print(os.environ["ENV"]) """ diff --git a/tests/unit/codegen/shared/compilation/test_function_compilation.py b/tests/unit/codegen/shared/compilation/test_function_compilation.py index da6ccf44a..7c786cc5b 100644 --- a/tests/unit/codegen/shared/compilation/test_function_compilation.py +++ b/tests/unit/codegen/shared/compilation/test_function_compilation.py @@ -4,7 +4,7 @@ from codegen.shared.exceptions.compilation import InvalidUserCodeException -def test_valid_func_str_should_not_raise(): +def test_valid_func_str_should_not_raise() -> None: func_str = """ from codegen.sdk.core.codebase import Codebase @@ -17,7 +17,7 @@ def execute(codebase: Codebase): pytest.fail("Unexpected InvalidUserCodeException raised") -def test_valid_func_str_with_nested_should_not_raise(): +def test_valid_func_str_with_nested_should_not_raise() -> None: func_str = """ from codegen.sdk.core.codebase import Codebase @@ -33,7 +33,7 @@ def nested(): pytest.fail("Unexpected InvalidUserCodeException raised") -def test_compile_syntax_error_indent_error_raises(): +def test_compile_syntax_error_indent_error_raises() -> None: func_str = """ def execute(codebase: Codebase): a = 1 @@ -47,7 +47,7 @@ def execute(codebase: Codebase): assert "> 3: a = 1" in error_msg -def test_compile_syntax_error_raises(): +def test_compile_syntax_error_raises() -> None: func_str = """ def execute(codebase: Codebase): print "syntax error" @@ -61,7 +61,7 @@ def execute(codebase: Codebase): assert '> 3: print "syntax error"' in error_msg -def test_compile_non_syntax_error_unicode_error_raises(): +def test_compile_non_syntax_error_unicode_error_raises() -> None: func_str = """ def execute(codebase: Codebase): print("hello")\udcff @@ -78,7 +78,7 @@ def execute(codebase: Codebase): assert "'utf-8' codec can't encode character '\\udcff'" in error_msg -def test_exec_error_non_syntax_error_zero_division_raises(): +def test_exec_error_non_syntax_error_zero_division_raises() -> None: """This is to test that we're handling errors (ex: ZeroDivisionError) that are raised during `exec` properly. NOTE: this case wouldn't happen with an actual func_str from create_function_str_from_codeblock b/c the func_str would just take in a codebase. @@ -95,7 +95,7 @@ def execute(codebase: Codebase, exec_error: int = 1/0): assert "> 2: def execute(codebase: Codebase, exec_error: int = 1/0):" in error_msg -def test_exec_error_non_syntax_error_name_error_raises(): +def test_exec_error_non_syntax_error_name_error_raises() -> None: """This is to test that we're handling errors (ex: NameError) that are raised during `exec` properly. NOTE: this case wouldn't happen with an actual func_str from create_function_str_from_codeblock b/c the func_str would not have any patches. @@ -113,7 +113,7 @@ def execute(codebase: Codebase): assert '> 2: @patch("foo", return_value="bar")' in error_msg -def test_func_str_uses_custom_scope_var_does_not_raise(): +def test_func_str_uses_custom_scope_var_does_not_raise() -> None: """This tests if a func_str references a var that is included in custom scope, it will not raise a NameError. This is to test the case when a group of codemods is run and a later one relies on a local defined in a previous one. """ diff --git a/tests/unit/codegen/shared/compilation/test_function_construction.py b/tests/unit/codegen/shared/compilation/test_function_construction.py index 1ba1fe6c9..a3bfc55ab 100644 --- a/tests/unit/codegen/shared/compilation/test_function_construction.py +++ b/tests/unit/codegen/shared/compilation/test_function_construction.py @@ -3,7 +3,7 @@ from codegen.shared.compilation.function_construction import create_function_str_from_codeblock -def test_no_execute_func_wraps(): +def test_no_execute_func_wraps() -> None: codeblock = """ print(len(codebase.files)) """ @@ -18,7 +18,7 @@ def execute(codebase: Codebase, pr_options: PROptions | None = None, pr = None, ) -def test_func_name_already_exists(): +def test_func_name_already_exists() -> None: codeblock = """ def execute(codebase: Codebase): print(len(codebase.files)) @@ -27,7 +27,7 @@ def execute(codebase: Codebase): assert codeblock in func -def test_func_name_not_execute(): +def test_func_name_not_execute() -> None: codeblock = """ print(len(codebase.files)) """ @@ -42,7 +42,7 @@ def not_execute(codebase: Codebase, pr_options: PROptions | None = None, pr = No ) -def test_function_str_includes_imports(): +def test_function_str_includes_imports() -> None: codeblock = """ print(len(codebase.files)) """ diff --git a/tests/unit/codegen/shared/compilation/test_string_to_code.py b/tests/unit/codegen/shared/compilation/test_string_to_code.py index 7d71ab50e..cc92d0a66 100644 --- a/tests/unit/codegen/shared/compilation/test_string_to_code.py +++ b/tests/unit/codegen/shared/compilation/test_string_to_code.py @@ -8,7 +8,7 @@ from codegen.shared.exceptions.control_flow import StopCodemodException -def test_syntax_error_raises(): +def test_syntax_error_raises() -> None: codeblock = """ print "syntax error" """ @@ -20,7 +20,7 @@ def test_syntax_error_raises(): assert 'print "syntax error"' in error_msg -def test_print_os_environ_raises(): +def test_print_os_environ_raises() -> None: codeblock = """ print(os.environ["ENV"]) """ @@ -28,7 +28,7 @@ def test_print_os_environ_raises(): create_execute_function_from_codeblock(codeblock=codeblock) -def test_print_calls_codebase_log(): +def test_print_calls_codebase_log() -> None: """Test print is monkey patched to call codebase.log""" codeblock = """ print("actually codebase.log") @@ -40,7 +40,7 @@ def test_print_calls_codebase_log(): assert mock_log.call_args_list[0][0][0] == "actually codebase.log" -def test_set_custom_scope_does_not_raise(): +def test_set_custom_scope_does_not_raise() -> None: """Test if the custom scope is set and the codeblock uses a var defined in the scope, it does not raise a NameError.""" codeblock = """ print(local_a) @@ -53,7 +53,7 @@ def test_set_custom_scope_does_not_raise(): @patch("codegen.shared.compilation.string_to_code.logger") -def test_stop_codemod_execution_logs_and_raises(mock_logger): +def test_stop_codemod_execution_logs_and_raises(mock_logger) -> None: codeblock = """ local_a = "this is local_a" raise StopCodemodException("test exception") @@ -65,7 +65,7 @@ def test_stop_codemod_execution_logs_and_raises(mock_logger): mock_logger.info.call_args_list[1][0][0] == "Stopping codemod due to StopCodemodException: test exception" -def test_references_import_from_generated_imports_does_not_raise(): +def test_references_import_from_generated_imports_does_not_raise() -> None: codeblock = """ print(os.getcwd()) # test external import print(MessageType.GITHUB) # test gs private import @@ -77,7 +77,7 @@ def test_references_import_from_generated_imports_does_not_raise(): assert mock_log.call_count == 3 -def test_references_import_not_in_generated_imports_raises_runtime_error(): +def test_references_import_not_in_generated_imports_raises_runtime_error() -> None: codeblock = """ print(Chainable.__name__) """ @@ -90,7 +90,7 @@ def test_references_import_not_in_generated_imports_raises_runtime_error(): assert "> 1: print(Chainable.__name__)" in error_msg -def test_error_during_execution_raises_runtime_error(): +def test_error_during_execution_raises_runtime_error() -> None: codeblock = """ print(var_that_does_not_exist) """ @@ -105,7 +105,7 @@ def test_error_during_execution_raises_runtime_error(): @pytest.mark.xfail(reason="TODO(CG-9581): fix codeblocks with return statements") -def test_return_statement_still_returns_locals(): +def test_return_statement_still_returns_locals() -> None: """Test if there is a return statement in a customer code block, the function should still return the locals""" codeblock = """ local_a = "this is local_a" diff --git a/tests/unit/codemods/test_codemod.py b/tests/unit/codemods/test_codemod.py index a9e7b1949..6d00a026d 100644 --- a/tests/unit/codemods/test_codemod.py +++ b/tests/unit/codemods/test_codemod.py @@ -2,8 +2,8 @@ from codemods.codemod import Codemod -def test_codemod_execute(): - def sample_execute(codebase: Codebase): +def test_codemod_execute() -> None: + def sample_execute(codebase: Codebase) -> None: for file in codebase.files: file.content = "print('hello')" diff --git a/tests/unit/skills/implementations/ai_skills.py b/tests/unit/skills/implementations/ai_skills.py index 16829b5df..87f236511 100644 --- a/tests/unit/skills/implementations/ai_skills.py +++ b/tests/unit/skills/implementations/ai_skills.py @@ -25,7 +25,7 @@ class RefactorClass(Skill, ABC): @staticmethod @skill_impl(test_cases=[RefactorClassPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: my_class = codebase.get_symbol("MyClass", optional=True) if my_class is None: msg = "MyClass not found in codebase" @@ -49,7 +49,7 @@ class GenerateDocstrings(Skill, ABC): @staticmethod @skill_impl(test_cases=[GenerateDocstringsPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: for cls in codebase.classes: for method in cls.methods: new_docstring = codebase.ai( @@ -73,7 +73,7 @@ class WriteTest(Skill, ABC): @staticmethod @skill_impl(test_cases=[WriteTestPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: my_function = codebase.get_function("my_function", optional=True) if my_function is None: msg = "my_function not found in codebase" @@ -100,7 +100,7 @@ class RenameMethods(Skill, ABC): @staticmethod @skill_impl(test_cases=[RenameMethodsPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: for cls in codebase.classes: for method in cls.methods: new_name = codebase.ai(f"Create a better name for the method {method.name}.", target=method) diff --git a/tests/unit/skills/implementations/asyncify_skills.py b/tests/unit/skills/implementations/asyncify_skills.py index d6640adae..66c3bd65d 100644 --- a/tests/unit/skills/implementations/asyncify_skills.py +++ b/tests/unit/skills/implementations/asyncify_skills.py @@ -133,7 +133,7 @@ class AsyncifyFunctionSkill(Skill): """ @staticmethod - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: pass @staticmethod @@ -158,7 +158,7 @@ def python_skill_func(codebase: PyCodebaseType) -> callable: @staticmethod @skill_impl([SkillTestCase(files=ts_files)], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: func_to_convert = codebase.get_function("funcToConvert") convert_queue = deque([func_to_convert]) diff --git a/tests/unit/skills/implementations/code_block_skills.py b/tests/unit/skills/implementations/code_block_skills.py index 1001bb7bb..bcdb07b68 100644 --- a/tests/unit/skills/implementations/code_block_skills.py +++ b/tests/unit/skills/implementations/code_block_skills.py @@ -85,7 +85,7 @@ class UnwrapFunctionBody(Skill, ABC): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCasePyFile(input=py_input, output=py_output_unwrap_function)])], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Unwraps the body of all functions in the file""" # iterate through all functions in the codebase for function in codebase.functions: @@ -94,7 +94,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCaseTSFile(input=ts_input, output=ts_output_unwrap_function)])], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Unwraps the body of all functions in the file""" # iterate through all functions in the codebase for function in codebase.functions: @@ -145,7 +145,7 @@ class UnwrapIfStatement(Skill, ABC): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCasePyFile(input=py_input, output=py_output_unwrap_if_statement)])], language=ProgrammingLanguage.PYTHON) @skill_impl([SkillTestCase(files=[SkillTestCaseTSFile(input=ts_input, output=ts_output_unwrap_if_statement)])], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: for file in codebase.files: for if_block in file.code_block.if_blocks: if_block.consequence_block.unwrap() @@ -174,7 +174,7 @@ class UnwrapWithStatement(Skill, ABC): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCasePyFile(input=py_input, output=py_output_unwrap_with_statement)])], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Unwraps the body of all with statements in the file""" # for all functions in the codebase for function in codebase.functions: @@ -185,7 +185,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], language=ProgrammingLanguage.TYPESCRIPT, ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """With Statements are not supported in TypeScript""" ... @@ -377,7 +377,7 @@ class ConvertStatementToArgument(Skill, ABC): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCasePyFile(input=py_convert_test_assertion_input, output=py_convert_test_assertion_output)])], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Transform test assertion statements into an argument to test functions that make a call to a http method.""" methods = ["post", "get", "put", "delete", "patch"] client_name = "client" @@ -415,7 +415,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCaseTSFile(input=ts_convert_test_assertion_input, output=ts_convert_test_assertion_output)])], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Transform test assertion statements into an argument to test functions that make a call to a http method.""" methods = ["post", "get", "put", "delete", "patch"] client_name = "client" diff --git a/tests/unit/skills/implementations/comment_skills.py b/tests/unit/skills/implementations/comment_skills.py index 917662a27..07e03492e 100644 --- a/tests/unit/skills/implementations/comment_skills.py +++ b/tests/unit/skills/implementations/comment_skills.py @@ -52,7 +52,7 @@ class EslintCommentSkill(Skill): @staticmethod @skill_impl(ts_test_cases, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: ESLINT_RULE = "@typescript-eslint/no-explicit-any" # Iterate over all files in the codebase diff --git a/tests/unit/skills/implementations/decorator_skills.py b/tests/unit/skills/implementations/decorator_skills.py index 549101580..7874c2847 100644 --- a/tests/unit/skills/implementations/decorator_skills.py +++ b/tests/unit/skills/implementations/decorator_skills.py @@ -44,7 +44,7 @@ class AddDecoratorToFunction(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Adds a decorator to each function or method in the codebase if they do not already have it.""" # get the decorator symbol decorator_symbol = codebase.get_symbol("my_decorator") @@ -72,7 +72,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """TODO: Implement this method @Rishi Desai""" ... @@ -134,7 +134,7 @@ class UpdateDocStringOfDecoratedMethods(Skill, ABC): language=ProgrammingLanguage.PYTHON, prompt="Update the docstring of class methods if it has a decarators containing `with_user` in its name by appending 'OPERATES ON USER DATA'.", ) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: for cls in codebase.classes: for method in cls.methods: if method.decorators and any(["with_user" in dec.name for dec in method.decorators]): @@ -146,7 +146,7 @@ def python_skill_func(codebase: CodebaseType): language=ProgrammingLanguage.TYPESCRIPT, prompt="Update the docstring of class methods if it has a decarators containing `withUser` in its name by appending 'OPERATES ON USER DATA'.", ) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: for cls in codebase.classes: for method in cls.methods: if method.decorators and any(["withUser" in dec.name for dec in method.decorators]): diff --git a/tests/unit/skills/implementations/eval_skills.py b/tests/unit/skills/implementations/eval_skills.py index 7e4a6b38e..f20922257 100644 --- a/tests/unit/skills/implementations/eval_skills.py +++ b/tests/unit/skills/implementations/eval_skills.py @@ -25,7 +25,7 @@ class AddCopyrightHeaderSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_1")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: r"""Adds the following header to all files in the codebase: 'Copyright (c) Codegen.\nAll rights reserved.\n\n'""" for file in codebase.files: # Adds header to the file. Note: comments are added @@ -33,7 +33,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([SkillTestCase.from_dir(EVAL_SKILLS_TEST_DIR / "sample_ts_1")], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: r"""Adds the following header to all files in the codebase: '// Copyright (c) Codegen.\n// All rights reserved.\n\n'""" for file in codebase.files: # Adds header to the file. Note: comments are added @@ -46,7 +46,7 @@ class MoveFooFunctionsSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_2")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Move all functions starting with 'foo' to foo.py.""" # get the foo.py file if it exists, otherwise create it Note: extension is included in the file name foo_file = codebase.get_file("foo.py") if codebase.has_file("foo.py") else codebase.create_file("foo.py") @@ -60,7 +60,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not Implemented""" ... @@ -71,7 +71,7 @@ class AddDecoratorToFooFunctionsSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_3")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Adds the following decorator to all functions starting with 'foo': '@decorator_function'.""" # get the decorator_function symbol decorator_symbol = codebase.get_symbol("decorator_function") @@ -90,7 +90,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not Implemented""" ... @@ -101,7 +101,7 @@ class RenameFooToBarSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_4")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Renames all functions starting with 'foo' to start with 'bar'.""" # for each function in the codebase for function in codebase.functions: @@ -112,7 +112,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not Implemented""" # for each function in the codebase for function in codebase.functions: @@ -128,7 +128,7 @@ class AddReturnTypeHintSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_5")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Adds an int return type hint to all functions starting with 'foo'.""" # for each function in the codebase for function in codebase.functions: @@ -139,7 +139,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not Implemented""" ... @@ -158,7 +158,7 @@ class MoveEnumsToSeparateFileSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_7")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Moves any enums within a file into a file called: `enums.py`. Creates the `enums.py` if it does not exist. If the original file only contains enums this skill renames it to `enums.py` """ @@ -188,7 +188,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not Implemented""" ... @@ -199,13 +199,13 @@ class UpdateOptionalTypeHintsSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_8")], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Replaces Optional[type] with type | None in all functions.""" # pattern to match Optional[type] optional_type_pattern = re.compile(r"Optional\[(.*?)]") # update optional parameter type hints - def update_optional_parameter_type_hints(function: PyFunction): + def update_optional_parameter_type_hints(function: PyFunction) -> None: # for each parameter in the function for parameter in function.parameters: # if the parameter is typed @@ -219,7 +219,7 @@ def update_optional_parameter_type_hints(function: PyFunction): # update the parameter type hint parameter.set_type_annotation(new_type) - def update_optional_return_type_hints(function: PyFunction): + def update_optional_return_type_hints(function: PyFunction) -> None: # if the function has a return type if function.return_type: # get the old return type @@ -249,7 +249,7 @@ def update_optional_return_type_hints(function: PyFunction): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not Implemented""" ... @@ -265,7 +265,7 @@ class DeleteUnusedSymbolsSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(EVAL_SKILLS_TEST_DIR / "sample_ts_3")], language=ProgrammingLanguage.TYPESCRIPT) @skill_impl([SkillTestCase.from_dir(filepath=EVAL_SKILLS_TEST_DIR / "sample_py_6")], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # for each symbol in the codebase for symbol in codebase.symbols: # if the symbol has no usages @@ -290,7 +290,7 @@ class MarkInternalFunctionsSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(EVAL_SKILLS_TEST_DIR / "sample_ts_7")], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Marks all functions that are only used in the `app` directory as an internal function. Marks functions as internal by adding the @internal tag to the bottom of the docstring. """ @@ -325,7 +325,7 @@ def typescript_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Not implemented for Python""" ... @@ -343,7 +343,7 @@ class MoveNonDefaultExportedJSXComponentsSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase.from_dir(EVAL_SKILLS_TEST_DIR / "sample_ts_2")], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Moves all JSX components that are not exported by default into a new file that is in the same directory as the original file.""" # for each file in the codebase for file in codebase.files: @@ -376,6 +376,6 @@ def typescript_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Not implemented for Python""" ... diff --git a/tests/unit/skills/implementations/example_skills.py b/tests/unit/skills/implementations/example_skills.py index 1951a4419..b48f7475b 100644 --- a/tests/unit/skills/implementations/example_skills.py +++ b/tests/unit/skills/implementations/example_skills.py @@ -22,7 +22,7 @@ class SetGlobalVarValueSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Set the value of all global variables to 2 if their current assigned literal is 1""" for file in codebase.files: for v in file.global_vars: @@ -31,7 +31,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... @@ -48,7 +48,7 @@ class RenameGlobalVarSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Rename all global variables named 'x' to 'y'.""" for file in codebase.files: for v in file.global_vars: @@ -57,7 +57,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... @@ -136,7 +136,7 @@ class SkipAllTests(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Adds pytest.mark.skip decorator to all test functions with reason="This is a test".""" for file in codebase.files: for function in file.functions: @@ -152,7 +152,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... @@ -173,7 +173,7 @@ class AddTypeHintsSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Adds type hints to function parameters and return values.""" for function in codebase.functions: if not function.return_type: @@ -195,7 +195,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" """Adds type hints to function parameters and return values.""" for function in codebase.functions: @@ -241,7 +241,7 @@ class RenameClassSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Rename the class OldName => NewName""" old_name = "OldName" new_name = "NewName" @@ -252,7 +252,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... @@ -311,7 +311,7 @@ class AddWrapperFunctionSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Add a trivial wrapper function called `new_{function.name}` around each function and class method that just calls the original function.""" for file in codebase.files: for function in file.functions: @@ -327,7 +327,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... @@ -363,7 +363,7 @@ class RemoveUnusedImportsSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Removes unused import statements from the code.""" for file in codebase.files: for imp in file.imports: @@ -372,7 +372,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... @@ -406,7 +406,7 @@ class AddDocstringsSkill(Skill, ABC): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Adds docstrings to all functions and methods.""" for file in codebase.files: for function in file.functions: @@ -416,6 +416,6 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Not implemented for TypeScript""" ... diff --git a/tests/unit/skills/implementations/export_skills.py b/tests/unit/skills/implementations/export_skills.py index a30fea49c..92f4d1625 100644 --- a/tests/unit/skills/implementations/export_skills.py +++ b/tests/unit/skills/implementations/export_skills.py @@ -63,7 +63,7 @@ class ExportSkills(Skill): @staticmethod @skill_impl(ts_test_cases, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: """Convert default exports to named exports in TypeScript""" for file in codebase.files: for export in file.exports: diff --git a/tests/unit/skills/implementations/expressions/dict.py b/tests/unit/skills/implementations/expressions/dict.py index f8e38bbbb..59e88f5ba 100644 --- a/tests/unit/skills/implementations/expressions/dict.py +++ b/tests/unit/skills/implementations/expressions/dict.py @@ -18,7 +18,7 @@ class DictToSchema(Skill): @staticmethod @skill_impl(test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: PyCodebaseType): + def python_skill_func(codebase: PyCodebaseType) -> None: """Converts a dictionary into a Schema object. Converts the key value pairs into arguments for the constructor""" # iterate over all global vars for v in codebase.global_vars: diff --git a/tests/unit/skills/implementations/expressions/list.py b/tests/unit/skills/implementations/expressions/list.py index 8852e6a92..adfa66965 100644 --- a/tests/unit/skills/implementations/expressions/list.py +++ b/tests/unit/skills/implementations/expressions/list.py @@ -27,7 +27,7 @@ class AppendToGlobalList(Skill, ABC): @staticmethod @skill_impl(test_cases_append_py, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases_append_ts, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: """Append 2 to the global list variable 'a' in Python.""" a = codebase.get_symbol("a") a.value.append("2") @@ -58,7 +58,7 @@ class RemoveFromGlobalList(Skill, ABC): @staticmethod @skill_impl(test_cases_remove_py, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases_remove_ts, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: """Remove 2 from the global list variable 'a'""" a = codebase.get_symbol("a", optional=True) assert a, "Symbol 'a' not found" @@ -87,7 +87,7 @@ class ClearGlobalList(Skill, ABC): @staticmethod @skill_impl(test_cases_clear_py, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases_clear_ts, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: """Clear the global list variable 'a'""" a = codebase.get_symbol("a") a.value.clear() diff --git a/tests/unit/skills/implementations/expressions/type.py b/tests/unit/skills/implementations/expressions/type.py index fb9eecd67..3eb1cca75 100644 --- a/tests/unit/skills/implementations/expressions/type.py +++ b/tests/unit/skills/implementations/expressions/type.py @@ -30,7 +30,7 @@ class AppendTypeToUnionTypeSkill(Skill): @staticmethod @skill_impl(test_cases_append_py, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """If the type of 'a' is a UnionType, append "str" to it if it doesn't already exist""" a: PyAssignment = codebase.get_symbol("a") if isinstance(a.type, UnionType): @@ -39,7 +39,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases_append_ts, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """If the type of 'a' is a UnionType, append "str" to it if it doesn't already exist""" a: TSAssignment = codebase.get_symbol("a") if isinstance(a.type, UnionType): @@ -73,7 +73,7 @@ class ConvertToBuiltInTypeSkill(Skill): @staticmethod @skill_impl([SkillTestCase(files=[SkillTestCasePyFile(input=built_in_type_input, output=built_in_type_output)])], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: """Replaces type annotations using typing module with builtin types.""" import_replacements = {"List": "list", "Dict": "dict", "Set": "set", "Tuple": "tuple"} @@ -92,6 +92,6 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl([], language=ProgrammingLanguage.TYPESCRIPT, ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """The typing package is only available in Python""" ... diff --git a/tests/unit/skills/implementations/feature_flag_deletion_skill.py b/tests/unit/skills/implementations/feature_flag_deletion_skill.py index 17744025f..303d13693 100644 --- a/tests/unit/skills/implementations/feature_flag_deletion_skill.py +++ b/tests/unit/skills/implementations/feature_flag_deletion_skill.py @@ -101,7 +101,7 @@ class DeleteRolledOutFeatureFlagSkill(Skill, ABC): @staticmethod @skill_impl([SkillTestCase(files=py_files)], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: PyCodebaseType): + def python_skill_func(codebase: PyCodebaseType) -> None: """Implements the feature flag deletion for Python codebases.""" feature_flag = codebase.get_symbol("ROLLED_OUT_FLAG_TO_DELETE") feature_flag_name = feature_flag.name @@ -123,7 +123,7 @@ def python_skill_func(codebase: PyCodebaseType): @staticmethod @skill_impl([SkillTestCase(files=ts_files)], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: """Implements the feature flag deletion for TypeScript codebases.""" feature_flag = codebase.get_symbol("ROLLED_OUT_FLAG_TO_DELETE") feature_flag_name = feature_flag.name diff --git a/tests/unit/skills/implementations/function_skills.py b/tests/unit/skills/implementations/function_skills.py index 0168af72a..22403bde0 100644 --- a/tests/unit/skills/implementations/function_skills.py +++ b/tests/unit/skills/implementations/function_skills.py @@ -55,14 +55,14 @@ class AppendParameterSkill(Skill, ABC): @staticmethod @skill_impl(py_test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: PyCodebaseType): + def python_skill_func(codebase: PyCodebaseType) -> None: """Append a parameter to the function signature in Python""" foo = codebase.get_symbol("foo") foo.parameters.append("b: int") @staticmethod @skill_impl(ts_test_cases, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: """Append a parameter to the function signature in TypeScript""" foo = codebase.get_symbol("foo") foo.parameters.append("b: number") diff --git a/tests/unit/skills/implementations/global_var_skills.py b/tests/unit/skills/implementations/global_var_skills.py index bf38f9873..b89fa154e 100644 --- a/tests/unit/skills/implementations/global_var_skills.py +++ b/tests/unit/skills/implementations/global_var_skills.py @@ -92,7 +92,7 @@ class DeleteUnusedLoggerSkill(Skill, ABC): @staticmethod @skill_impl(py_test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: PyCodebaseType): + def python_skill_func(codebase: PyCodebaseType) -> None: """Remove all global variables that are defined as `logger = get_logger(__name__)` if they are unused""" for file in codebase.files: for var in file.global_vars: @@ -101,7 +101,7 @@ def python_skill_func(codebase: PyCodebaseType): @staticmethod @skill_impl(ts_test_cases, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: """Remove all global variables that are defined as `logger = getLogger()` if they are unused""" for file in codebase.files: for var in file.global_vars: diff --git a/tests/unit/skills/implementations/graph_viz/graph_viz_app_imports.py b/tests/unit/skills/implementations/graph_viz/graph_viz_app_imports.py index 4d4a4c11e..8af8bc793 100644 --- a/tests/unit/skills/implementations/graph_viz/graph_viz_app_imports.py +++ b/tests/unit/skills/implementations/graph_viz/graph_viz_app_imports.py @@ -29,7 +29,7 @@ class FileAppImportGraph(Skill, ABC): @staticmethod @skill_impl(test_cases=[FileAppImportGraphTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: import networkx as nx # Create a directed graph diff --git a/tests/unit/skills/implementations/graph_viz/graph_viz_call_graph.py b/tests/unit/skills/implementations/graph_viz/graph_viz_call_graph.py index 21bbd1a6f..059ad37c3 100644 --- a/tests/unit/skills/implementations/graph_viz/graph_viz_call_graph.py +++ b/tests/unit/skills/implementations/graph_viz/graph_viz_call_graph.py @@ -61,7 +61,7 @@ class CallGraphFromNode(Skill, ABC): @staticmethod @skill_impl(test_cases=[CallGraphFromNodeTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Create a directed graph G = nx.DiGraph() @@ -71,7 +71,7 @@ def skill_func(codebase: CodebaseType): # ===== [ Maximum Recursive Depth ] ===== MAX_DEPTH = 5 - def create_downstream_call_trace(parent: FunctionCall | Function | None = None, depth: int = 0): + def create_downstream_call_trace(parent: FunctionCall | Function | None = None, depth: int = 0) -> None: """Creates call graph for parent This function recurses through the call graph of a function and creates a visualization @@ -194,7 +194,7 @@ class CallGraphFilter(Skill, ABC): @staticmethod @skill_impl(test_cases=[CallGraphFilterTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Create a directed graph G = nx.DiGraph() @@ -212,7 +212,7 @@ def skill_func(codebase: CodebaseType): cls = codebase.get_class("MyClass") # Define a recursive function to traverse function calls - def create_filtered_downstream_call_trace(parent: FunctionCall | Function, current_depth, max_depth): + def create_filtered_downstream_call_trace(parent: FunctionCall | Function, current_depth, max_depth) -> None: if current_depth > max_depth: return @@ -289,7 +289,7 @@ class CallPathsBetweenNodes(Skill, ABC): @staticmethod @skill_impl(test_cases=[CallPathsBetweenNodesTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Create a directed graph G = nx.DiGraph() @@ -297,7 +297,7 @@ def skill_func(codebase: CodebaseType): MAX_DEPTH = 5 # Define a recursive function to traverse usages - def create_downstream_call_trace(parent: FunctionCall | Function, end: Callable, current_depth, max_depth): + def create_downstream_call_trace(parent: FunctionCall | Function, end: Callable, current_depth, max_depth) -> None: if current_depth > max_depth: return diff --git a/tests/unit/skills/implementations/graph_viz/graph_viz_dead_code.py b/tests/unit/skills/implementations/graph_viz/graph_viz_dead_code.py index 27b8ed1c3..5a6ed69b4 100644 --- a/tests/unit/skills/implementations/graph_viz/graph_viz_dead_code.py +++ b/tests/unit/skills/implementations/graph_viz/graph_viz_dead_code.py @@ -111,7 +111,7 @@ class DeadCode(Skill, ABC): @staticmethod @skill_impl(test_cases=[PyDeadCodeTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Create a directed graph to visualize dead and second-order dead code G = nx.DiGraph() diff --git a/tests/unit/skills/implementations/graph_viz/graph_viz_dir_tree.py b/tests/unit/skills/implementations/graph_viz/graph_viz_dir_tree.py index b5fec176b..d87674fb0 100644 --- a/tests/unit/skills/implementations/graph_viz/graph_viz_dir_tree.py +++ b/tests/unit/skills/implementations/graph_viz/graph_viz_dir_tree.py @@ -85,7 +85,7 @@ class RepoDirTree(Skill, ABC): @staticmethod @skill_impl(test_cases=[PyRepoDirTreeTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Create a directed graph G = nx.DiGraph() diff --git a/tests/unit/skills/implementations/graph_viz/graph_viz_foreign_key.py b/tests/unit/skills/implementations/graph_viz/graph_viz_foreign_key.py index 608487257..e57fc8f08 100644 --- a/tests/unit/skills/implementations/graph_viz/graph_viz_foreign_key.py +++ b/tests/unit/skills/implementations/graph_viz/graph_viz_foreign_key.py @@ -97,7 +97,7 @@ class ForeignKeyGraph(Skill, ABC): @staticmethod @skill_impl(test_cases=[PyForeignKeyGraphTest], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Create a mapping dictionary to hold relationships foreign_key_mapping = {} diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/import-loops.py b/tests/unit/skills/implementations/guides/code-quality-metrics/import-loops.py index ce61479ec..b349a45a4 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/import-loops.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/import-loops.py @@ -87,7 +87,7 @@ class ImportCycleDetectionAndVisualization(Skill, ABC): @staticmethod @skill_impl(test_cases=[ImportCycleDetectionAndVisualizationTest], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: G: DiGraph = networkx.DiGraph() # iterate over all imports diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/large-files.py b/tests/unit/skills/implementations/guides/code-quality-metrics/large-files.py index e8f6eb3b4..6a250516c 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/large-files.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/large-files.py @@ -71,7 +71,7 @@ class CountLargeFiles(Skill, ABC): @staticmethod @skill_impl(test_cases=[CountLargeFilesPyTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[CountLargeFilesTSTest], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: file_to_lines = {} for file in codebase.files: file_to_lines[file.filepath] = file.end_point[0] diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/missing-documentation.py b/tests/unit/skills/implementations/guides/code-quality-metrics/missing-documentation.py index a16dbe70c..f7d909229 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/missing-documentation.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/missing-documentation.py @@ -69,7 +69,7 @@ class AutoDocstringGenerator(Skill, ABC): @staticmethod @skill_impl(test_cases=[AutoDocstringGeneratorTest], skip_test=True, language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: functions_without_docstring = 0 # Iterate over all functions in the codebase for function in codebase.functions: diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/unnamed-kwargs.py b/tests/unit/skills/implementations/guides/code-quality-metrics/unnamed-kwargs.py index a918c0ac7..ac3f92254 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/unnamed-kwargs.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/unnamed-kwargs.py @@ -83,7 +83,7 @@ class CountUnnamedKeywordArguments(Skill, ABC): @staticmethod @skill_impl(test_cases=[CountUnnamedKeywordArgumentsTest], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: unnamed_kwargs_count = 0 # Iterate through all files in the codebase diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-attributes.py b/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-attributes.py index b728735bc..95d954e82 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-attributes.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-attributes.py @@ -120,7 +120,7 @@ class CountUntypedAttributesInCodebase(Skill, ABC): @staticmethod @skill_impl(test_cases=[CountUntypedAttributesInCodebasePyTest], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[CountUntypedAttributesInCodebaseTSTest], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: untyped_attributes_count = 0 typed_attributes_count = 0 diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-parameters.py b/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-parameters.py index b179f7599..8b20fd185 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-parameters.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-parameters.py @@ -76,7 +76,7 @@ class CountUntypedParameters(Skill, ABC): @staticmethod @skill_impl(test_cases=[CountUntypedParametersTest], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: untitled_parameters_count = 0 # Iterate through all files in the codebase diff --git a/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-return-types.py b/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-return-types.py index 82b492792..11187956c 100644 --- a/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-return-types.py +++ b/tests/unit/skills/implementations/guides/code-quality-metrics/untyped-return-types.py @@ -50,7 +50,7 @@ class CountUntypedReturnStatements(Skill, ABC): @staticmethod @skill_impl(test_cases=[CountUntypedReturnStatementsTest], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: untitled_return_count = 0 # Iterate through all files in the codebase diff --git a/tests/unit/skills/implementations/guides/codebase-ai.py b/tests/unit/skills/implementations/guides/codebase-ai.py index 4140257bc..251070d33 100644 --- a/tests/unit/skills/implementations/guides/codebase-ai.py +++ b/tests/unit/skills/implementations/guides/codebase-ai.py @@ -44,7 +44,7 @@ class MethodSummaryGenerator(Skill, ABC): @staticmethod @skill_impl(test_cases=[MethodSummaryGeneratorPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # get the method and its parent class method: Function = codebase.get_class("MyClass").get_method("target_method") parent_class = method.parent @@ -150,7 +150,7 @@ class FunctionDecomposition(Skill, ABC): @staticmethod @skill_impl(test_cases=[FunctionDecompositionPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Get the large function large_function = codebase.get_function("large_function") @@ -194,7 +194,7 @@ class DocstringGenerator(Skill, ABC): @staticmethod @skill_impl(test_cases=[DocstringGeneratorPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Loop through all functions in the codebase for function in codebase.functions: # Generate a new docstring for the function @@ -237,7 +237,7 @@ class DocstringFormatter(Skill, ABC): @staticmethod @skill_impl(test_cases=[DocstringFormatterPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Loop through all functions in the codebase for function in codebase.functions: # Convert the docstring to the new format diff --git a/tests/unit/skills/implementations/guides/complex-renaming.py b/tests/unit/skills/implementations/guides/complex-renaming.py index b3694b2aa..4c2970119 100644 --- a/tests/unit/skills/implementations/guides/complex-renaming.py +++ b/tests/unit/skills/implementations/guides/complex-renaming.py @@ -66,7 +66,7 @@ class RenameFunctionAndUpdateReferences(Skill, ABC): @staticmethod @skill_impl(test_cases=[RenameFunctionAndUpdateReferencesPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Find the symbol to rename old_function = codebase.get_file("path/to/file.py").get_function("old_function_name") @@ -75,7 +75,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[RenameFunctionAndUpdateReferencesTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Find the symbol to rename old_function = codebase.get_file("path/to/file.ts").get_function("old_function_name") @@ -99,7 +99,7 @@ class AutoRenameFunction(Skill, ABC): @staticmethod @skill_impl(test_cases=[RenameFunctionAndUpdateReferencesPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: function = codebase.get_file("path/to/file.py").get_function("old_function_name") function.rename("new_function_name") @@ -110,7 +110,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[RenameFunctionAndUpdateReferencesTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: function = codebase.get_file("path/to/file.ts").get_function("old_function_name") function.rename("new_function_name") @@ -189,7 +189,7 @@ class AutomaticClassRenamingWithReferenceUpdate(Skill, ABC): @staticmethod @skill_impl(test_cases=[AutoRenameClassPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[AutoRenameClassTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: old_class = codebase.get_class("OldClassName") old_class.rename("NewClassName") @@ -277,7 +277,7 @@ class RemoveDeprecatedPrefixfromFunctions(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveDeprecatedPrefixfromFunctionsPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RemoveDeprecatedPrefixfromFunctionsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate over all codebase functions for function in codebase.functions: # Filter for functions starting with deprecated_ @@ -398,7 +398,7 @@ class RenameMethodInBaseAndSubclasses(Skill, ABC): @staticmethod @skill_impl(test_cases=[RenameMethodPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RenameMethodTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: base_class = codebase.get_class("BaseClass") old_method = base_class.get_method("old_method") diff --git a/tests/unit/skills/implementations/guides/creating-documentation.py b/tests/unit/skills/implementations/guides/creating-documentation.py index c7fa4cf3e..08a458f45 100644 --- a/tests/unit/skills/implementations/guides/creating-documentation.py +++ b/tests/unit/skills/implementations/guides/creating-documentation.py @@ -57,7 +57,7 @@ class UpdateFunctionDocstring(Skill, ABC): @staticmethod @skill_impl(test_cases=[UpdateFunctionDocstringPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Get a specific function function = codebase.get_file("path/to/file.py").get_function("my_function") # Update its docstring @@ -65,7 +65,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[UpdateFunctionDocstringTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Get a specific function function = codebase.get_file("path/to/file.ts").get_function("myFunction") # Update its docstring @@ -157,7 +157,7 @@ class CalculateDocumentationCoverage(Skill, ABC): @staticmethod @skill_impl(test_cases=[CalculateDocumentationCoveragePyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[CalculateDocumentationCoverageTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Initialize counters for total and documented functions/classes count_total = 0 count_documented = 0 @@ -230,7 +230,7 @@ class DocstringEnhancer(Skill, ABC): @staticmethod @skill_impl(test_cases=[DocstringEnhancerPyTestCase], skip_test=True, language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[DocstringEnhancerTSTestCase], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: for function in codebase.functions: current_docstring = function.docstring if current_docstring: @@ -363,7 +363,7 @@ class StaticDocstringGenerator(Skill, ABC): @staticmethod @skill_impl(test_cases=[StaticDocstringGeneratorPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[StaticDocstringGeneratorTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate through all functions in the codebase for function in codebase.functions: # Check if the function has a docstring diff --git a/tests/unit/skills/implementations/guides/deleting-dead-code.py b/tests/unit/skills/implementations/guides/deleting-dead-code.py index c6f7a99c4..67a1bc47a 100644 --- a/tests/unit/skills/implementations/guides/deleting-dead-code.py +++ b/tests/unit/skills/implementations/guides/deleting-dead-code.py @@ -122,7 +122,7 @@ class RemoveUnusedSymbols(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveUnusedSymbolsPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RemoveUnusedSymbolsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate through all symbols in the codebase for symbol in codebase.symbols: # Check if the symbols has no usages @@ -195,7 +195,7 @@ class RemoveUnusedFunctions(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveUnusedFunctionsPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RemoveUnusedFunctionsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate through all functions in the codebase for function in codebase.functions: # Check if the function has no usages and no call sites @@ -262,7 +262,7 @@ class RemoveUnusedImports(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveUnusedImportsPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RemoveUnusedImportsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate through all imports in the codebase for import_stmt in codebase.imports: # Check if the import statement has no usages @@ -335,7 +335,7 @@ class RemoveUnusedLocalVariableAssignments(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveUnusedLocalVariableAssignmentsPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RemoveUnusedLocalVariableAssignmentsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate through all functions in the codebase for func in codebase.functions: # Iterate through local variable assignments in the function @@ -407,7 +407,7 @@ class RemoveUnusedParametersAndArguments(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveUnusedParametersAndArgumentsPyTestCase], language=ProgrammingLanguage.PYTHON, ignore=True) @skill_impl(test_cases=[RemoveUnusedParametersAndArgumentsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT, ignore=True) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # iterate through all functions in the codebase for function in codebase.functions: for param in function.parameters: @@ -492,7 +492,7 @@ class RemoveUnusedClasses(Skill, ABC): @staticmethod @skill_impl(test_cases=[RemoveUnusedClassesPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[RemoveUnusedClassesTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Iterate through all classes in the codebase for cls in codebase.classes: # Check if the class has no usages @@ -582,7 +582,7 @@ class CleanUpCodebase(Skill, ABC): @staticmethod @skill_impl(test_cases=[CleanUpCodebasePyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[CleanUpCodebaseTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # iterate through all files in the codebase for file in codebase.files: # Check if the file is empty diff --git a/tests/unit/skills/implementations/guides/increase-type-coverage.py b/tests/unit/skills/implementations/guides/increase-type-coverage.py index 52bf3a854..04e1c8f22 100644 --- a/tests/unit/skills/implementations/guides/increase-type-coverage.py +++ b/tests/unit/skills/implementations/guides/increase-type-coverage.py @@ -76,7 +76,7 @@ class CountTypedParametersSkill(Skill, ABC): @staticmethod @skill_impl(test_cases=[CountTypedParametersPyTestCase], language=ProgrammingLanguage.PYTHON) @skill_impl(test_cases=[CountTypedParametersTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Initialize counters for total parameters and typed parameters total_parameters = 0 typed_parameters = 0 @@ -200,7 +200,7 @@ class SetReturnTypeToNoneForFunctionsWithoutReturns(Skill, ABC): @staticmethod @skill_impl(test_cases=[SetReturnTypeToNoneForFunctionsWithoutReturnsPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Iterate through all files in the codebase for file in codebase.files: # Check if 'app' is in the file's filepath @@ -214,7 +214,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[SetReturnTypeToNoneForFunctionsWithoutReturnsTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Iterate through all files in the codebase for file in codebase.files: # Check if 'app' is in the file's filepath @@ -310,7 +310,7 @@ class ModifyReturnTypeSkill(Skill, ABC): @staticmethod @skill_impl(test_cases=[ModifyReturnTypeSkillPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify target_file = codebase.get_file("path/to/file.py") function = target_file.get_function("function_name") @@ -325,7 +325,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[ModifyReturnTypeSkillTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify target_file = codebase.get_file("path/to/file.ts") function = target_file.get_function("functionName") @@ -378,7 +378,7 @@ class ModifyReturnTypeWithNewParameter(Skill, ABC): @staticmethod @skill_impl(test_cases=[ModifyReturnTypeWithNewParameterTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify function = codebase.get_file("path/to/file.py").get_function("function_name") # def function_name() -> tuple[a, b]: ... @@ -419,7 +419,7 @@ class InspectFunctionReturnType(Skill, ABC): @staticmethod @skill_impl(test_cases=[InspectFunctionReturnTypePyTest], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Find the symbol to inspect function = codebase.get_file("path/to/file.py").get_function("function_name") # def function_name() -> tuple[a, b]: ... @@ -533,7 +533,7 @@ class InspectResolveFunctionReturnType(Skill, ABC): @staticmethod @skill_impl(test_cases=[InspectResolveFunctionReturnTypePySanityTest], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Find the symbol to inspect function = codebase.get_file("path/to/file.py").get_function("function_name") # from xyz import MyContainer, a, b @@ -544,7 +544,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[InspectResolveFunctionReturnTypeTSSanityTest], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Find the symbol to inspect function = codebase.get_file("path/to/file.ts").get_function("functionName") # import { MyContainer, a, b } './types' @@ -595,7 +595,7 @@ class ResolveAndRenameGlobalVariableType(Skill, ABC): @staticmethod @skill_impl(test_cases=[ResolveAndRenameGlobalVariableTypeTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: a = codebase.get_file("path/to/file.py").get_global_var("a") # a: MyType = ... print(a.type) # Resolves to MyType diff --git a/tests/unit/skills/implementations/guides/manipulating-collections.py b/tests/unit/skills/implementations/guides/manipulating-collections.py index d95c24bd7..0d68f8cce 100644 --- a/tests/unit/skills/implementations/guides/manipulating-collections.py +++ b/tests/unit/skills/implementations/guides/manipulating-collections.py @@ -54,7 +54,7 @@ class AddParameterToFunction(Skill, ABC): @staticmethod @skill_impl(test_cases=[AddParameterToFunctionPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify function = codebase.get_file("path/to/file.py").get_function("function_name") @@ -63,7 +63,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[AddParameterToFunctionTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify function = codebase.get_file("path/to/file.ts").get_function("functionName") @@ -116,7 +116,7 @@ class ModifyDictionaryValue(Skill, ABC): @staticmethod @skill_impl(test_cases=[ModifyDictionaryValuePyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # var = {"example_key": "example_value"} # Find the symbol to modify var = codebase.get_file("path/to/file.py").get_global_var("foo") @@ -134,7 +134,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[ModifyDictionaryValueTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # var = {"example_key": "example_value"} # Find the symbol to modify var = codebase.get_file("path/to/file.ts").get_global_var("foo") @@ -182,7 +182,7 @@ class ConvertVariableToSchema(Skill, ABC): @staticmethod @skill_impl(test_cases=[ConvertVariableToSchemaPyTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify var = codebase.get_file("path/to/file.py").get_global_var("var_name") @@ -232,7 +232,7 @@ class AppendToGlobalVariableList(Skill, ABC): @staticmethod @skill_impl(test_cases=[AppendToGlobalVariableListPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify var = codebase.get_file("path/to/file.py").get_global_var("foo") @@ -246,7 +246,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[AppendToGlobalVariableListTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify var = codebase.get_file("path/to/file.ts").get_global_var("foo") @@ -292,7 +292,7 @@ class CheckFunctionDecoratorPresence(Skill, ABC): @staticmethod @skill_impl(test_cases=[CheckFunctionDecoratorPresencePyTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Find the symbol to modify function = codebase.get_file("path/to/file.py").get_function("function_name") diff --git a/tests/unit/skills/implementations/guides/organize-your-codebase.py b/tests/unit/skills/implementations/guides/organize-your-codebase.py index b78330b51..d9e9eb79e 100644 --- a/tests/unit/skills/implementations/guides/organize-your-codebase.py +++ b/tests/unit/skills/implementations/guides/organize-your-codebase.py @@ -67,7 +67,7 @@ class SplitFunctionsIntoSeparateFiles(Skill, ABC): @staticmethod @skill_impl(test_cases=[SplitFunctionsIntoSeparateFilesPyTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Retrieve the Python file from the codebase file = codebase.get_file("path/to/file.py") # Iterate through the functions in the file @@ -147,7 +147,7 @@ class MoveSymbolDemonstration(Skill, ABC): @staticmethod @skill_impl(test_cases=[MoveSymbolDemonstrationPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: source_file = codebase.get_file("path/to/source_file.py") # =====[ Code Snippet ]===== # Get the symbol @@ -159,7 +159,7 @@ def python_skill_func(codebase: CodebaseType): @staticmethod @skill_impl(test_cases=[MoveSymbolDemonstrationTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: source_file = codebase.get_file("path/to/source_file.ts") # =====[ Code Snippet ]===== # Get the symbol @@ -247,7 +247,7 @@ class MoveSymbolWithUpdatedImports(Skill, ABC): @staticmethod @skill_impl(test_cases=[MoveSymbolWithUpdatedImportsPyTestCase], language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: CodebaseType): + def python_skill_func(codebase: CodebaseType) -> None: symbol_to_move = codebase.get_symbol("symbol_to_move") dst_file = codebase.create_file("new_file.py") symbol_to_move.move_to_file(dst_file, strategy="update_all_imports") @@ -337,7 +337,7 @@ class MoveSymbolWithAddBackEdgeStrategy(Skill, ABC): @staticmethod @skill_impl(test_cases=[MoveSymbolWithAddBackEdgeStrategyPyTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: symbol_to_move = codebase.get_symbol("symbol_to_move") dst_file = codebase.create_file("new_file.py") symbol_to_move.move_to_file(dst_file, strategy="add_back_edge") @@ -447,14 +447,14 @@ class MoveSymbolToFileWithDependencies(Skill, ABC): @staticmethod @skill_impl(test_cases=[MoveSymbolToFileWithDependenciesPyTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: my_symbol = codebase.get_symbol("my_symbol") dst_file = codebase.create_file("new_file.py") my_symbol.move_to_file(dst_file, include_dependencies=True) @staticmethod @skill_impl(test_cases=[MoveSymbolToFileWithDependenciesTSTestCase], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: my_symbol = codebase.get_symbol("mySymbol") dst_file = codebase.create_file("new_file.ts") my_symbol.move_to_file(dst_file, include_dependencies=True) @@ -526,7 +526,7 @@ class MoveSymbolsWithDependencies(Skill, ABC): @staticmethod @skill_impl(test_cases=[MoveSymbolsWithDependenciesPyTestCase], language=ProgrammingLanguage.PYTHON) - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: # Retrieve the source and destination files source_file = codebase.get_file("path/to/source_file.py") dest_file = codebase.get_file("path/to/destination_file.py") diff --git a/tests/unit/skills/implementations/if_statement_skills.py b/tests/unit/skills/implementations/if_statement_skills.py index aeb3b1991..33317ae28 100644 --- a/tests/unit/skills/implementations/if_statement_skills.py +++ b/tests/unit/skills/implementations/if_statement_skills.py @@ -93,7 +93,7 @@ class ReduceIfStatementConditionSkill(Skill, ABC): @staticmethod @skill_impl(py_test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: PyCodebaseType): + def python_skill_func(codebase: PyCodebaseType) -> None: """Simplifies the if/else control flow by reducing conditions that are `condition_to_set` to True""" for file in codebase.files: for function in file.functions: @@ -106,7 +106,7 @@ def python_skill_func(codebase: PyCodebaseType): @staticmethod @skill_impl(ts_test_cases, language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: """Simplifies the if/else control flow by reducing conditions that are `conditionToSet` to True""" for file in codebase.files: for function in file.functions: diff --git a/tests/unit/skills/implementations/move_symbols_skills.py b/tests/unit/skills/implementations/move_symbols_skills.py index a68b48eec..c6c182d1a 100644 --- a/tests/unit/skills/implementations/move_symbols_skills.py +++ b/tests/unit/skills/implementations/move_symbols_skills.py @@ -61,7 +61,7 @@ class MoveDataclassesSkills(Skill): @staticmethod @skill_impl(py_test_cases, language=ProgrammingLanguage.PYTHON) - def python_skill_func(codebase: PyCodebaseType): + def python_skill_func(codebase: PyCodebaseType) -> None: """Moves the dataclasses and adds a back edge to the moved symbol in the original file""" # Iterate over all files in the codebase for file in codebase.files: @@ -85,6 +85,6 @@ def python_skill_func(codebase: PyCodebaseType): @staticmethod @skill_impl([], language=ProgrammingLanguage.TYPESCRIPT, ignore=True) - def typescript_skill_func(codebase: CodebaseType): + def typescript_skill_func(codebase: CodebaseType) -> None: """Dataclasses is only available in Python""" ... diff --git a/tests/unit/skills/implementations/type_inheritance_skills.py b/tests/unit/skills/implementations/type_inheritance_skills.py index e6c3f54df..90057b306 100644 --- a/tests/unit/skills/implementations/type_inheritance_skills.py +++ b/tests/unit/skills/implementations/type_inheritance_skills.py @@ -224,12 +224,12 @@ def python_skill_func(codebase: PyCodebaseType) -> callable: pass @staticmethod - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: pass @staticmethod @skill_impl([SkillTestCase(files=ts_files_readonly)], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: """Given a type alias 'MyMapper', find all inherited or extended implementations of the type object. Loops through all codebase symbols and handles each symbol type accordingly. """ @@ -270,12 +270,12 @@ def python_skill_func(codebase: PyCodebaseType) -> callable: pass @staticmethod - def skill_func(codebase: CodebaseType): + def skill_func(codebase: CodebaseType) -> None: pass @staticmethod @skill_impl([SkillTestCase(files=ts_files_write)], language=ProgrammingLanguage.TYPESCRIPT) - def typescript_skill_func(codebase: TSCodebaseType): + def typescript_skill_func(codebase: TSCodebaseType) -> None: FUNC_NAME_TO_CONVERT = "convert" mapper_symbol: TypeAlias = codebase.get_symbol("MyMapper") @@ -331,7 +331,4 @@ def typescript_skill_func(codebase: TSCodebaseType): processed.add(f) if not f.is_async: f.asyncify() - - for call_site in f.call_sites: - if call_site.parent and isinstance(call_site.parent, Function): - funcs_to_asyncify.append(call_site.parent) + funcs_to_asyncify.extend(call_site.parent for call_site in f.call_sites if call_site.parent and isinstance(call_site.parent, Function)) diff --git a/tests/unit/skills/test_skills.py b/tests/unit/skills/test_skills.py index 11d83baca..f0c04e08a 100644 --- a/tests/unit/skills/test_skills.py +++ b/tests/unit/skills/test_skills.py @@ -13,7 +13,7 @@ # @pytest.pytestmark.skip(reason="AI skill testing") @pytest.mark.parametrize("skill, test_case", [(skill, case) for skill, case, _ in cases], ids=[skill.name + f"-{skill.language.name}-case-{case.name or idx}" for skill, case, idx in cases]) -def test_all_example_skills(tmpdir, skill: SkillImplementation, test_case: SkillTestCase, snapshot): +def test_all_example_skills(tmpdir, skill: SkillImplementation, test_case: SkillTestCase, snapshot) -> None: with get_codebase_session(tmpdir=tmpdir, programming_language=skill.language, files=test_case.to_input_dict(), repo_config=BaseRepoConfig(), verify_output=False, verify_input=False) as codebase: skill._skill_func(codebase) codebase.commit()