diff --git a/src/codegen/cli/codemod/convert.py b/src/codegen/cli/codemod/convert.py index fe5d62337..512eeebd4 100644 --- a/src/codegen/cli/codemod/convert.py +++ b/src/codegen/cli/codemod/convert.py @@ -7,7 +7,7 @@ def convert_to_cli(input: str, language: str, name: str) -> str: # from app.codemod.compilation.models.context import CodemodContext #from app.codemod.compilation.models.pr_options import PROptions -from graph_sitter import {codebase_type} +from codegen.sdk import {codebase_type} context: Any diff --git a/src/codegen/cli/utils/count_functions.py b/src/codegen/cli/utils/count_functions.py index 9c4543ffa..014c28a60 100644 --- a/src/codegen/cli/utils/count_functions.py +++ b/src/codegen/cli/utils/count_functions.py @@ -5,7 +5,7 @@ # from app.codemod.compilation.models.context import CodemodContext # from app.codemod.compilation.models.pr_options import PROptions -# from graph_sitter import PyCodebaseType +# from codegen.sdk import PyCodebaseType # context: CodemodContext diff --git a/src/codegen/git/__init__.py b/src/codegen/git/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codegen/git/configs/constants.py b/src/codegen/git/configs/constants.py index d8483c60a..3df55ca70 100644 --- a/src/codegen/git/configs/constants.py +++ b/src/codegen/git/configs/constants.py @@ -3,3 +3,5 @@ CODEGEN_BOT_NAME = "codegen-bot" CODEGEN_BOT_EMAIL = "team+codegenbot@codegen.sh" CODEOWNERS_FILEPATHS = [".github/CODEOWNERS", "CODEOWNERS", "docs/CODEOWNERS"] +HIGHSIDE_REMOTE_NAME = "highside" +LOWSIDE_REMOTE_NAME = "lowside" diff --git a/src/codegen/git/models/codemod_context.py b/src/codegen/git/models/codemod_context.py new file mode 100644 index 000000000..50d34478e --- /dev/null +++ b/src/codegen/git/models/codemod_context.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import logging +from typing import Any + +from pydantic import BaseModel, Field + +from codegen.git.models.pull_request_context import PullRequestContext + +logger = logging.getLogger(__name__) + + +class CodemodContext(BaseModel): + CODEMOD_ID: int | None = None + CODEMOD_LINK: str | None = None + CODEMOD_AUTHOR: str | None = None + TEMPLATE_ARGS: dict[str, Any] = Field(default_factory=dict) + + # TODO: add fields for version + # CODEMOD_VERSION_ID: int | None = None + # CODEMOD_VERSION_AUTHOR: str | None = None + + PULL_REQUEST: PullRequestContext | None = None + + @classmethod + def _render_template(cls, template_schema: dict[str, str], template_values: dict[str, Any]) -> dict[str, Any]: + template_data: dict[str, Any] = {} + for var_name, var_value in template_values.items(): + var_type = template_schema.get(var_name) + + if var_type == "list": + template_data[var_name] = [str(v).strip() for v in var_value.split(",")] + else: + template_data[var_name] = str(var_value) + return template_data diff --git a/src/codegen/git/models/github_named_user_context.py b/src/codegen/git/models/github_named_user_context.py new file mode 100644 index 000000000..b3eb22dd2 --- /dev/null +++ b/src/codegen/git/models/github_named_user_context.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class GithubNamedUserContext(BaseModel): + """Represents a GitHub user parsed from a webhook payload""" + + login: str + email: str | None = None + + @classmethod + def from_payload(cls, payload: dict) -> "GithubNamedUserContext": + return cls(login=payload.get("login"), email=payload.get("email")) diff --git a/src/codegen/git/models/pr_options.py b/src/codegen/git/models/pr_options.py new file mode 100644 index 000000000..97f4bcd63 --- /dev/null +++ b/src/codegen/git/models/pr_options.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + +from codegen.utils.decorators.docs import apidoc + + +@apidoc +class PROptions(BaseModel): + """Options for generating a PR.""" + + title: str | None = None + body: str | None = None + labels: list[str] | None = None # TODO: not used until we add labels to GithubPullRequestModel + force_push_head_branch: bool | None = None diff --git a/src/codegen/git/models/pr_part_context.py b/src/codegen/git/models/pr_part_context.py new file mode 100644 index 000000000..162aed84b --- /dev/null +++ b/src/codegen/git/models/pr_part_context.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + + +class PRPartContext(BaseModel): + """Represents a GitHub pull request part parsed from a webhook payload""" + + ref: str + sha: str + + @classmethod + def from_payload(cls, payload: dict) -> "PRPartContext": + return cls(ref=payload.get("ref"), sha=payload.get("sha")) diff --git a/src/codegen/git/models/pull_request_context.py b/src/codegen/git/models/pull_request_context.py new file mode 100644 index 000000000..5bad57a4b --- /dev/null +++ b/src/codegen/git/models/pull_request_context.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel + +from codegen.git.models.github_named_user_context import GithubNamedUserContext +from codegen.git.models.pr_part_context import PRPartContext +from codegen.git.schemas.github import GithubType + + +class PullRequestContext(BaseModel): + """Represents a GitHub pull request""" + + id: int + url: str + html_url: str + number: int + state: str + title: str + user: GithubNamedUserContext + body: str + draft: bool + head: PRPartContext + base: PRPartContext + merged: bool | None + merged_by: dict | None + additions: int | None + deletions: int | None + changed_files: int | None + github_type: GithubType | None = None + webhook_data: dict | None = None + + @classmethod + def from_payload(cls, webhook_payload: dict) -> "PullRequestContext": + webhook_data = webhook_payload.get("pull_request", {}) + return cls( + id=webhook_data.get("id"), + url=webhook_data.get("url"), + html_url=webhook_data.get("html_url"), + number=webhook_data.get("number"), + state=webhook_data.get("state"), + title=webhook_data.get("title"), + user=GithubNamedUserContext.from_payload(webhook_data.get("user", {})), + body=webhook_data.get("body"), + draft=webhook_data.get("draft"), + head=PRPartContext.from_payload(webhook_data.get("head", {})), + base=PRPartContext.from_payload(webhook_data.get("base", {})), + merged=webhook_data.get("merged"), + merged_by=webhook_data.get("merged_by", {}), + additions=webhook_data.get("additions"), + deletions=webhook_data.get("deletions"), + changed_files=webhook_data.get("changed_files"), + github_type=GithubType.from_url(webhook_data.get("html_url")), + webhook_data=webhook_data, + ) diff --git a/src/codegen/git/repo_operator/repo_operator.py b/src/codegen/git/repo_operator/repo_operator.py index 25e8c8b23..1c8a9d3ef 100644 --- a/src/codegen/git/repo_operator/repo_operator.py +++ b/src/codegen/git/repo_operator/repo_operator.py @@ -18,7 +18,7 @@ from codegen.git.schemas.enums import CheckoutResult, FetchResult from codegen.git.schemas.repo_config import BaseRepoConfig from codegen.utils.performance.stopwatch_utils import stopwatch -from codegen.utils.time_utils import humanize_duration +from codegen.utils.performance.time_utils import humanize_duration logger = logging.getLogger(__name__) diff --git a/src/codegen/gscli/generate/runner_imports.py b/src/codegen/gscli/generate/runner_imports.py index 12ceeac41..5466f4feb 100644 --- a/src/codegen/gscli/generate/runner_imports.py +++ b/src/codegen/gscli/generate/runner_imports.py @@ -12,10 +12,10 @@ import plotly """.strip() CODEGEN_IMPORTS = """ -from app.codemod.compilation.models.context import CodemodContext -from app.codemod.compilation.models.github_named_user_context import GithubNamedUserContext -from app.codemod.compilation.models.pr_part_context import PRPartContext -from app.codemod.compilation.models.pull_request_context import PullRequestContext +from codegen.git.models.codemod_context import CodemodContext +from codegen.git.models.github_named_user_context import GithubNamedUserContext +from codegen.git.models.pr_part_context import PRPartContext +from codegen.git.models.pull_request_context import PullRequestContext """ # TODO: these should also be made public (i.e. included in the docs site) GS_PRIVATE_IMPORTS = """ diff --git a/src/codegen/gscli/generate/utils.py b/src/codegen/gscli/generate/utils.py index 8cca3f849..d579f9288 100644 --- a/src/codegen/gscli/generate/utils.py +++ b/src/codegen/gscli/generate/utils.py @@ -32,11 +32,11 @@ def generate_builtins_file(path_to_builtins: str, language_type: LanguageType): # This file is auto-generated, do not modify manually {{all_imports}} -from app.codemod.compilation.models.context import CodemodContext -from app.codemod.compilation.models.pr_options import PROptions -from app.codemod.compilation.models.github_named_user_context import GithubNamedUserContext -from app.codemod.compilation.models.pr_part_context import PRPartContext -from app.codemod.compilation.models.pull_request_context import PullRequestContext +from codegen.git.models.codemod_context import CodemodContext +from codegen.git.models.pr_options import PROptions +from codegen.git.models.github_named_user_context import GithubNamedUserContext +from codegen.git.models.pr_part_context import PRPartContext +from codegen.git.models.pull_request_context import PullRequestContext from codegen.sdk.codebase.flagging.code_flag import MessageType as MessageType {"\n".join(inspect.getsource(codebase).splitlines()[-2:])} diff --git a/src/codegen/sdk/codebase/flagging/flags.py b/src/codegen/sdk/codebase/flagging/flags.py index 817d0c25b..eef0a6686 100644 --- a/src/codegen/sdk/codebase/flagging/flags.py +++ b/src/codegen/sdk/codebase/flagging/flags.py @@ -1,14 +1,11 @@ from dataclasses import dataclass, field -from typing import TYPE_CHECKING from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.enums import MessageType +from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.core.interfaces.editable import Editable from codegen.utils.decorators.docs import noapidoc -if TYPE_CHECKING: - from app.codemod.types import Group - @dataclass class Flags: @@ -69,7 +66,7 @@ def set_find_mode(self, find_mode: bool) -> None: self._find_mode = find_mode @noapidoc - def set_active_group(self, group: "Group") -> None: + def set_active_group(self, group: Group) -> None: """Will only fix these flags.""" # TODO - flesh this out more with Group datatype and GroupBy self._active_group = group.flags diff --git a/src/codegen/sdk/codebase/flagging/group.py b/src/codegen/sdk/codebase/flagging/group.py new file mode 100644 index 000000000..58f6f9e95 --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/group.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from dataclasses_json import dataclass_json + +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + +DEFAULT_GROUP_ID = 0 + + +@dataclass_json +@dataclass +class Group: + group_by: GroupBy + segment: str + flags: list[CodeFlag] | None = None + id: int = DEFAULT_GROUP_ID diff --git a/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py new file mode 100644 index 000000000..7eaac241f --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/all_grouper.py @@ -0,0 +1,21 @@ +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + + +class AllGrouper(BaseGrouper): + """Group all flags into one group.""" + + type: GroupBy = GroupBy.ALL + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + return [Group(group_by=GroupBy.ALL, segment="all", flags=flags)] if flags else [] + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + if segment != "all": + raise ValueError(f"❌ Invalid segment for AllGrouper: {segment}. Only 'all' is a valid segment.") + return Group(group_by=GroupBy.ALL, segment=segment, flags=flags) diff --git a/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py new file mode 100644 index 000000000..ec8b9904d --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/app_grouper.py @@ -0,0 +1,34 @@ +import logging + +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + +logger = logging.getLogger(__name__) + + +class AppGrouper(BaseGrouper): + """Group flags by segment=app. + Ex: apps/profile. + """ + + type: GroupBy = GroupBy.APP + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + unique_apps = list({"/".join(flag.filepath.split("/")[:3]) for flag in flags}) + groups = [] + for idx, app in enumerate(unique_apps): + matches = [f for f in flags if f.filepath.startswith(app)] + if len(matches) > 0: + groups.append(Group(id=idx, group_by=GroupBy.APP, segment=app, flags=matches)) + return groups + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + segment_flags = [f for f in flags if f.filepath.startswith(segment)] + if len(segment_flags) == 0: + logger.warning(f"🤷‍♀️ No flags found for APP segment: {segment}") + return Group(group_by=GroupBy.APP, segment=segment, flags=segment_flags) diff --git a/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py new file mode 100644 index 000000000..56ada1612 --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/base_grouper.py @@ -0,0 +1,26 @@ +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + + +class BaseGrouper: + """Base class of all groupers. + Children of this class should include in their doc string: + - a short desc of what the segment format is. ex: for FileGrouper the segment is a filename + """ + + type: GroupBy + + def __init__(self) -> None: + if type is None: + raise ValueError("Must set type in BaseGrouper") + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + raise NotImplementedError("Must implement create_all_groups in BaseGrouper") + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + """TODO: handle the case when 0 flags are passed in""" + raise NotImplementedError("Must implement create_single_group in BaseGrouper") diff --git a/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py new file mode 100644 index 000000000..8e01f6ba1 --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/codeowner_grouper.py @@ -0,0 +1,40 @@ +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + +DEFAULT_CHUNK_SIZE = 5 + + +class CodeownerGrouper(BaseGrouper): + """Group flags by CODEOWNERS. + + Parses .github/CODEOWNERS and groups by each possible codeowners + + Segment should be either a github username or github team name. + """ + + type: GroupBy = GroupBy.CODEOWNER + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + owner_to_group: dict[str, Group] = {} + no_owner_group = Group(group_by=GroupBy.CODEOWNER, segment="@no-owner", flags=[]) + for idx, flag in enumerate(flags): + flag_owners = repo_operator.codeowners_parser.of(flag.filepath) # TODO: handle codeowners_parser could be null + if not flag_owners: + no_owner_group.flags.append(flag) + continue + # NOTE: always use the first owner. ex if the line is /dir @team1 @team2 then use team1 + flag_owner = flag_owners[0][1] + group = owner_to_group.get(flag_owner, Group(id=idx, group_by=GroupBy.CODEOWNER, segment=flag_owner, flags=[])) + group.flags.append(flag) + owner_to_group[flag_owner] = group + + no_owner_group.id = len(owner_to_group) + return [*list(owner_to_group.values()), no_owner_group] + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + raise NotImplementedError("TODO: implement single group creation") diff --git a/src/codegen/sdk/codebase/flagging/groupers/constants.py b/src/codegen/sdk/codebase/flagging/groupers/constants.py new file mode 100644 index 000000000..2fc2a29ab --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/constants.py @@ -0,0 +1,15 @@ +from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper +from codegen.sdk.codebase.flagging.groupers.app_grouper import AppGrouper +from codegen.sdk.codebase.flagging.groupers.codeowner_grouper import CodeownerGrouper +from codegen.sdk.codebase.flagging.groupers.file_chunk_grouper import FileChunkGrouper +from codegen.sdk.codebase.flagging.groupers.file_grouper import FileGrouper +from codegen.sdk.codebase.flagging.groupers.instance_grouper import InstanceGrouper + +ALL_GROUPERS = [ + AllGrouper, + AppGrouper, + CodeownerGrouper, + FileChunkGrouper, + FileGrouper, + InstanceGrouper, +] diff --git a/src/codegen/sdk/codebase/flagging/groupers/enums.py b/src/codegen/sdk/codebase/flagging/groupers/enums.py new file mode 100644 index 000000000..c84b2f413 --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/enums.py @@ -0,0 +1,11 @@ +from enum import StrEnum + + +class GroupBy(StrEnum): + ALL = "all" + APP = "app" + CODEOWNER = "codeowner" + FILE = "file" + FILE_CHUNK = "file_chunk" + HOT_COLD = "hot_cold" + INSTANCE = "instance" diff --git a/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py new file mode 100644 index 000000000..abe62dd37 --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/file_chunk_grouper.py @@ -0,0 +1,47 @@ +import logging + +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy +from codegen.utils.string.csv_utils import comma_separated_to_list, list_to_comma_separated + +logger = logging.getLogger(__name__) + +DEFAULT_CHUNK_SIZE = 5 + + +class FileChunkGrouper(BaseGrouper): + """Group flags by a chunk of files. + Ex: if chunk size is 10 then a Group only contains flags from max 10 unique files. + TODO: currently only supports a harcoded chunk size of 5. + + Segment is a comma separated list of filenames. + """ + + type: GroupBy = GroupBy.FILE_CHUNK + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + map = {f.filepath: f for f in flags} + filenames = sorted(map.keys()) + chunks = chunk_list(filenames, DEFAULT_CHUNK_SIZE) + groups = [] + for idx, chunk in enumerate(chunks): + chunk_flags = [map[filename] for filename in chunk] + groups.append(Group(id=idx, group_by=GroupBy.FILE_CHUNK, segment=list_to_comma_separated(chunk), flags=chunk_flags)) + return groups + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + segment_filepaths = comma_separated_to_list(segment) + all_segment_flags = [f for f in flags if f.filepath in segment_filepaths] + if len(all_segment_flags) == 0: + logger.warning(f"🤷‍♀️ No flags found for FILE_CHUNK segment: {segment_filepaths}") + return Group(group_by=GroupBy.FILE_CHUNK, segment=segment, flags=all_segment_flags) + + +def chunk_list(lst: list, chk_size: int) -> list[list[str]]: + for index in range(0, len(lst), chk_size): + yield lst[index : index + chk_size] diff --git a/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py new file mode 100644 index 000000000..6cc537afa --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/file_grouper.py @@ -0,0 +1,33 @@ +import logging + +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + +logger = logging.getLogger(__name__) + + +class FileGrouper(BaseGrouper): + """Group flags by file. + Segment is the filename. + """ + + type: GroupBy = GroupBy.FILE + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + groups = [] + filenames = sorted(list({f.filepath for f in flags})) + for idx, filename in enumerate(filenames): + filename_flags = [flag for flag in flags if flag.filepath == filename] + groups.append(Group(id=idx, group_by=GroupBy.FILE, segment=filename, flags=filename_flags)) + return groups + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + segment_flags = [flag for flag in flags if flag.filepath == segment] + if len(segment_flags) == 0: + logger.warning(f"🤷‍♀️ No flags found for FILE segment: {segment}") + return Group(group_by=GroupBy.FILE, segment=segment, flags=segment_flags) diff --git a/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py b/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py new file mode 100644 index 000000000..6607c58cc --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/instance_grouper.py @@ -0,0 +1,26 @@ +from codegen.git.repo_operator.remote_repo_operator import RemoteRepoOperator +from codegen.sdk.codebase.flagging.code_flag import CodeFlag +from codegen.sdk.codebase.flagging.group import Group +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + + +class InstanceGrouper(BaseGrouper): + """Group flags by flags. haha + One Group per flag. + """ + + type: GroupBy = GroupBy.INSTANCE + + @staticmethod + def create_all_groups(flags: list[CodeFlag], repo_operator: RemoteRepoOperator | None = None) -> list[Group]: + return [Group(id=idx, group_by=GroupBy.INSTANCE, segment=f.hash, flags=[f]) for idx, f in enumerate(flags)] + + @staticmethod + def create_single_group(flags: list[CodeFlag], segment: str, repo_operator: RemoteRepoOperator | None = None) -> Group: + # TODO: not sure if it's possible to regenerate a group for instance grouper b/c it needs to re-generate/re-find the flag. might need to rely on the flag meta 🤦‍♀️ + try: + flag = CodeFlag.from_json(segment) + return Group(group_by=GroupBy.INSTANCE, segment=segment, flags=[flag]) + except Exception as e: + raise ValueError(f"Unable to deserialize segment ({segment}) into CodeFlag. Unable to create group.") diff --git a/src/codegen/sdk/codebase/flagging/groupers/utils.py b/src/codegen/sdk/codebase/flagging/groupers/utils.py new file mode 100644 index 000000000..64bc737c1 --- /dev/null +++ b/src/codegen/sdk/codebase/flagging/groupers/utils.py @@ -0,0 +1,13 @@ +from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper +from codegen.sdk.codebase.flagging.groupers.base_grouper import BaseGrouper +from codegen.sdk.codebase.flagging.groupers.constants import ALL_GROUPERS +from codegen.sdk.codebase.flagging.groupers.enums import GroupBy + + +def get_grouper_by_group_by(group_by: GroupBy | None) -> type[BaseGrouper]: + if group_by is None: + return AllGrouper + matched_groupers = [x for x in ALL_GROUPERS if x.type == group_by] + if len(matched_groupers) > 0: + return matched_groupers[0] + raise ValueError(f"No grouper found for group_by={group_by}. Did you add to ALL_GROUPERS?") diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 0c018bc3a..c5a1fda09 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -8,7 +8,7 @@ from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Generic, Literal, TypeVar, Unpack, overload +from typing import Generic, Literal, TypeVar, Unpack, overload import plotly.graph_objects as go import rich.repr @@ -31,6 +31,7 @@ from codegen.sdk.codebase.diff_lite import DiffLite from codegen.sdk.codebase.flagging.code_flag import CodeFlag from codegen.sdk.codebase.flagging.enums import FlagKwargs +from codegen.sdk.codebase.flagging.group import Group from codegen.sdk.codebase.span import Span from codegen.sdk.core.assignment import Assignment from codegen.sdk.core.class_definition import Class @@ -74,9 +75,6 @@ from codegen.utils.performance.stopwatch_utils import stopwatch from codegen.visualizations.visualization_manager import VisualizationManager -if TYPE_CHECKING: - from app.codemod.types import Group - logger = logging.getLogger(__name__) MAX_LINES = 10000 # Maximum number of lines of text allowed to be logged @@ -890,7 +888,7 @@ def set_find_mode(self, find_mode: bool) -> None: self.G.flags.set_find_mode(find_mode) @noapidoc - def set_active_group(self, group: "Group") -> None: + def set_active_group(self, group: Group) -> None: """Will only fix these flags.""" # TODO - flesh this out more with Group datatype and GroupBy self.G.flags.set_active_group(group) diff --git a/src/codegen/utils/__init__.py b/src/codegen/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/codegen/utils/compilation/README.md b/src/codegen/utils/compilation/README.md new file mode 100644 index 000000000..626ceb7ea --- /dev/null +++ b/src/codegen/utils/compilation/README.md @@ -0,0 +1,8 @@ +Utils around compiling a user's codeblock into a function. + +This includes: +- Raising on any dangerous operations in the codeblock +- Catching and logging any compilation errors +- Monkey patching built-ins like print +- etc + diff --git a/src/codegen/utils/compilation/codeblock_validation.py b/src/codegen/utils/compilation/codeblock_validation.py new file mode 100644 index 000000000..c7737891e --- /dev/null +++ b/src/codegen/utils/compilation/codeblock_validation.py @@ -0,0 +1,13 @@ +import re + +from codegen.utils.compilation.exceptions import DangerousUserCodeException + + +def check_for_dangerous_operations(user_code: str) -> None: + """If codeblock has dangerous operations (ex: exec, os.environ, etc) then raise an error and prevent the user from executing it.""" + dangerous_operation_patterns = [ + r"\b(os\.environ|locals|globals)\b", # Environment variables and scope access + ] + pattern = "|".join(dangerous_operation_patterns) + if re.search(pattern, user_code, re.IGNORECASE): + raise DangerousUserCodeException("The codeblock contains potentially dangerous operations that are not allowed.") diff --git a/src/codegen/utils/compilation/exception_utils.py b/src/codegen/utils/compilation/exception_utils.py new file mode 100644 index 000000000..8b97d7a88 --- /dev/null +++ b/src/codegen/utils/compilation/exception_utils.py @@ -0,0 +1,53 @@ +import logging +from types import FrameType, TracebackType + +logger = logging.getLogger(__name__) + + +def get_offset_traceback(tb_lines: list[str], line_offset: int = 0, filenameFilter: str = "") -> str: + """Generate a traceback string with offset line numbers. + + :param tb_lines: lines output for the traceback + :param line_offset: Number of lines to offset the traceback + :return: A string containing the offset traceback + """ + # Process each line of the traceback + offset_tb_lines = [] + for line in tb_lines: + if line.lstrip().startswith("File"): + if line.lstrip().startswith(f'File "{filenameFilter}"') and "execute" not in line: + # This line contains file and line number information + parts = line.split(", line ") + if len(parts) > 1: + # Offset the line number + line_num = int(parts[1].split(",")[0]) + new_line_num = line_num - line_offset + line = f"{parts[0]}, line {new_line_num}{','.join(parts[1].split(',')[1:])}" + offset_tb_lines.append(line) + else: + offset_tb_lines.append(line) + + # Join the processed lines back into a single string + return "".join(offset_tb_lines) + + +def get_local_frame(exc_type: type[BaseException], exc_value: BaseException, exc_traceback: TracebackType) -> FrameType | None: + LOCAL_FILENAME = "" + LOCAL_MODULE_DIR = "codegen-backend/app/" + tb = exc_traceback + while tb and ((tb.tb_next and tb.tb_frame.f_code.co_filename != LOCAL_FILENAME) or LOCAL_MODULE_DIR in tb.tb_frame.f_code.co_filename): + tb = tb.tb_next + + frame = tb.tb_frame if tb else None + return frame + + +def get_local_frame_context(frame: FrameType): + local_vars = {k: v for k, v in frame.f_locals.items() if not k.startswith("__")} + if "print" in local_vars: + del local_vars["print"] + if "codebase" in local_vars: + del local_vars["codebase"] + if "pr_options" in local_vars: + del local_vars["pr_options"] + return local_vars diff --git a/src/codegen/utils/compilation/exceptions.py b/src/codegen/utils/compilation/exceptions.py new file mode 100644 index 000000000..e9a274e61 --- /dev/null +++ b/src/codegen/utils/compilation/exceptions.py @@ -0,0 +1,10 @@ +class UserCodeException(Exception): + """Custom exception for any issues in user code.""" + + +class DangerousUserCodeException(UserCodeException): + """Custom exception user code that has dangerous / not permitted operations.""" + + +class InvalidUserCodeException(UserCodeException): + """Custom exception for user code that can be compiled/executed. Ex: syntax errors, indentation errors, name errors etc.""" diff --git a/src/codegen/utils/compilation/function_compilation.py b/src/codegen/utils/compilation/function_compilation.py new file mode 100644 index 000000000..0e16cb19d --- /dev/null +++ b/src/codegen/utils/compilation/function_compilation.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import linecache +import logging +import sys +import traceback +from collections.abc import Callable + +from codegen.utils.compilation.exceptions import InvalidUserCodeException + +logger = logging.getLogger(__name__) + + +def get_compilation_error_context(filename: str, line_number: int, window_size: int = 2): + """Get lines of context around SyntaxError + Exceptions that occur when compiling functions.""" + start = max(1, line_number - window_size) + end = line_number + window_size + 1 + lines = [] + for i in range(start, end): + line = linecache.getline(filename, i).rstrip() + if line: + lines.append((i, line)) + return lines + + +def safe_compile_function_string(custom_scope: dict, func_name: str, func_str: str) -> Callable: + # =====[ Add function string to linecache ]===== + # (This is necessary for the traceback to work correctly) + linecache.cache[""] = (len(func_str), None, func_str.splitlines(True), "") + + # =====[ Compile & exec the code ]===== + # This will throw errors if there is invalid syntax + try: + # First, try to compile the code to catch syntax errors + logger.info(f"Compiling function: {func_name} ...") + compiled_code = compile(func_str, "", "exec") + # If compilation succeeds, try to execute the code + logger.info(f"Compilation succeeded. exec-ing function: {func_name} ...") + exec(compiled_code, custom_scope, custom_scope) + + # =====[ Catch SyntaxErrors ]===== + except SyntaxError as e: + error_class = e.__class__.__name__ + detail = e.args[0] + line_number = e.lineno + context_lines = get_compilation_error_context("", line_number) + context_str = "\n".join(f"{'>' if i == line_number else ' '} {i}: {line}" for i, line in context_lines) + error_line = linecache.getline("", line_number).strip() + caret_line = " " * (e.offset - 1) + "^" * (len(error_line) - e.offset + 1) + error_message = f"{error_class} at line {line_number}: {detail}\n {error_line}\n {caret_line}\n{context_str}" + raise InvalidUserCodeException(error_message) from e + + # =====[ All other Exceptions ]===== + except Exception as e: + error_class = e.__class__.__name__ + detail = str(e) + _, _, tb = sys.exc_info() + line_number = traceback.extract_tb(tb)[-1].lineno + context_lines = get_compilation_error_context("", line_number) + context_str = "\n".join(f"{'>' if i == line_number else ' '} {i}: {line}" for i, line in context_lines) + error_line = linecache.getline("", line_number).strip() + error_message = f"{error_class} at line {line_number}: {detail}\n {error_line}\n{context_str}" + raise InvalidUserCodeException(error_message) from e + + finally: + # Clear the cache to free up memory + linecache.clearcache() + + return custom_scope.get(func_name) diff --git a/src/codegen/utils/compilation/function_construction.py b/src/codegen/utils/compilation/function_construction.py new file mode 100644 index 000000000..6d199fec5 --- /dev/null +++ b/src/codegen/utils/compilation/function_construction.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import logging +import re + +from codegen.utils.compilation.function_imports import get_generated_imports + +logger = logging.getLogger(__name__) + + +def create_function_str_from_codeblock(codeblock: str, func_name: str) -> str: + """Creates a function string from a codeblock.""" + # =====[ Make an `execute` function string w/ imports ]===== + func_str = wrap_codeblock_in_function(codeblock, func_name) + + # =====[ Add imports to the top ]===== + func_str = get_imports_string().format(func_str=func_str) + return func_str + + +def wrap_codeblock_in_function(codeblock: str, func_name: str) -> str: + """Wrap a codeblock in a function with the specified name. + + Args: + codeblock (str): The code to be wrapped in a function. + func_name (str): The name to give the wrapping function. + + Note: + Skip wrapping if a function with the specified name already exists in the codeblock. + """ + if re.search(rf"\bdef\s+{func_name}\s*\(", codeblock): + logger.info(f"Codeblock already has a function named {func_name}. Skipping wrap.") + return codeblock + + # If not function func_name does not already exist, create a new function with the codeblock inside + user_code = indent_user_code(codeblock) + codeblock = f""" +def {func_name}(codebase: Codebase, pr_options: PROptions | None = None, pr = None, **kwargs): + print = codebase.log +{user_code} + """ + return codeblock + + +def indent_user_code(codeblock: str) -> str: + return "\n".join(f" {line}" for line in codeblock.strip().split("\n")) + + +def get_imports_string(): + """Gets imports marked with apidoc decorators. This list is autogenerated by generate_runner_imports""" + imports_str = get_generated_imports() + + func_str_template = """ + +{func_str} +""" + return imports_str + func_str_template diff --git a/src/codegen/utils/compilation/function_imports.py b/src/codegen/utils/compilation/function_imports.py new file mode 100644 index 000000000..5937d906e --- /dev/null +++ b/src/codegen/utils/compilation/function_imports.py @@ -0,0 +1,196 @@ +# This file is auto-generated, do not modify manually. Edit this in codegen-backend/cli/generate/runner_imports.py. +def get_generated_imports(): + return """ +# External imports +import os +import re +from pathlib import Path +import networkx as nx +import plotly + +# GraphSitter imports (private) + +from codegen.git.models.codemod_context import CodemodContext +from codegen.git.models.github_named_user_context import GithubNamedUserContext +from codegen.git.models.pr_part_context import PRPartContext +from codegen.git.models.pull_request_context import PullRequestContext + +from codegen.sdk.codebase.control_flow import StopCodemodException + +# GraphSitter imports (public) +from codegen.sdk.codebase.flagging.enums import FlagKwargs +from codegen.sdk.codebase.flagging.enums import MessageType +from codegen.sdk.codebase.span import Span +from codegen.sdk.core.assignment import Assignment +from codegen.sdk.core.class_definition import Class +from codegen.sdk.core.codebase import Codebase +from codegen.sdk.core.codebase import CodebaseType +from codegen.sdk.core.codebase import PyCodebaseType +from codegen.sdk.core.codebase import TSCodebaseType +from codegen.sdk.core.dataclasses.usage import Usage +from codegen.sdk.core.dataclasses.usage import UsageKind +from codegen.sdk.core.dataclasses.usage import UsageType +from codegen.sdk.core.detached_symbols.argument import Argument +from codegen.sdk.core.detached_symbols.code_block import CodeBlock +from codegen.sdk.core.detached_symbols.decorator import Decorator +from codegen.sdk.core.detached_symbols.function_call import FunctionCall +from codegen.sdk.core.detached_symbols.parameter import Parameter +from codegen.sdk.core.directory import Directory +from codegen.sdk.core.export import Export +from codegen.sdk.core.expressions.await_expression import AwaitExpression +from codegen.sdk.core.expressions.binary_expression import BinaryExpression +from codegen.sdk.core.expressions.boolean import Boolean +from codegen.sdk.core.expressions.chained_attribute import ChainedAttribute +from codegen.sdk.core.expressions.comparison_expression import ComparisonExpression +from codegen.sdk.core.expressions.expression import Expression +from codegen.sdk.core.expressions.generic_type import GenericType +from codegen.sdk.core.expressions.multi_expression import MultiExpression +from codegen.sdk.core.expressions.name import Name +from codegen.sdk.core.expressions.named_type import NamedType +from codegen.sdk.core.expressions.none_type import NoneType +from codegen.sdk.core.expressions.number import Number +from codegen.sdk.core.expressions.parenthesized_expression import ParenthesizedExpression +from codegen.sdk.core.expressions.placeholder_type import PlaceholderType +from codegen.sdk.core.expressions.string import String +from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression +from codegen.sdk.core.expressions.ternary_expression import TernaryExpression +from codegen.sdk.core.expressions.tuple_type import TupleType +from codegen.sdk.core.expressions.type import Type +from codegen.sdk.core.expressions.unary_expression import UnaryExpression +from codegen.sdk.core.expressions.union_type import UnionType +from codegen.sdk.core.expressions.unpack import Unpack +from codegen.sdk.core.expressions.value import Value +from codegen.sdk.core.external_module import ExternalModule +from codegen.sdk.core.file import File +from codegen.sdk.core.file import SourceFile +from codegen.sdk.core.function import Function +from codegen.sdk.core.import_resolution import Import +from codegen.sdk.core.interfaces.callable import Callable +from codegen.sdk.core.interfaces.editable import Editable +from codegen.sdk.core.interfaces.exportable import Exportable +from codegen.sdk.core.interfaces.has_block import HasBlock +from codegen.sdk.core.interfaces.has_name import HasName +from codegen.sdk.core.interfaces.has_value import HasValue +from codegen.sdk.core.interfaces.importable import Importable +from codegen.sdk.core.interfaces.typeable import Typeable +from codegen.sdk.core.interfaces.unwrappable import Unwrappable +from codegen.sdk.core.interfaces.usable import Usable +from codegen.sdk.core.placeholder.placeholder import Placeholder +from codegen.sdk.core.placeholder.placeholder_stub import StubPlaceholder +from codegen.sdk.core.placeholder.placeholder_type import TypePlaceholder +from codegen.sdk.core.statements.assignment_statement import AssignmentStatement +from codegen.sdk.core.statements.attribute import Attribute +from codegen.sdk.core.statements.block_statement import BlockStatement +from codegen.sdk.core.statements.catch_statement import CatchStatement +from codegen.sdk.core.statements.comment import Comment +from codegen.sdk.core.statements.export_statement import ExportStatement +from codegen.sdk.core.statements.expression_statement import ExpressionStatement +from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement +from codegen.sdk.core.statements.if_block_statement import IfBlockStatement +from codegen.sdk.core.statements.import_statement import ImportStatement +from codegen.sdk.core.statements.raise_statement import RaiseStatement +from codegen.sdk.core.statements.return_statement import ReturnStatement +from codegen.sdk.core.statements.statement import Statement +from codegen.sdk.core.statements.statement import StatementType +from codegen.sdk.core.statements.switch_case import SwitchCase +from codegen.sdk.core.statements.switch_statement import SwitchStatement +from codegen.sdk.core.statements.symbol_statement import SymbolStatement +from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement +from codegen.sdk.core.statements.while_statement import WhileStatement +from codegen.sdk.core.symbol import Symbol +from codegen.sdk.core.symbol_group import SymbolGroup +from codegen.sdk.core.symbol_groups.comment_group import CommentGroup +from codegen.sdk.core.symbol_groups.dict import Dict +from codegen.sdk.core.symbol_groups.dict import Pair +from codegen.sdk.core.symbol_groups.expression_group import ExpressionGroup +from codegen.sdk.core.symbol_groups.list import List +from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection +from codegen.sdk.core.symbol_groups.tuple import Tuple +from codegen.sdk.core.type_alias import TypeAlias +from codegen.sdk.python.assignment import PyAssignment +from codegen.sdk.python.class_definition import PyClass +from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock +from codegen.sdk.python.detached_symbols.decorator import PyDecorator +from codegen.sdk.python.detached_symbols.parameter import PyParameter +from codegen.sdk.python.expressions.chained_attribute import PyChainedAttribute +from codegen.sdk.python.expressions.conditional_expression import PyConditionalExpression +from codegen.sdk.python.expressions.generic_type import PyGenericType +from codegen.sdk.python.expressions.named_type import PyNamedType +from codegen.sdk.python.expressions.string import PyString +from codegen.sdk.python.expressions.union_type import PyUnionType +from codegen.sdk.python.file import PyFile +from codegen.sdk.python.function import PyFunction +from codegen.sdk.python.import_resolution import PyImport +from codegen.sdk.python.interfaces.has_block import PyHasBlock +from codegen.sdk.python.placeholder.placeholder_return_type import PyReturnTypePlaceholder +from codegen.sdk.python.statements.assignment_statement import PyAssignmentStatement +from codegen.sdk.python.statements.attribute import PyAttribute +from codegen.sdk.python.statements.block_statement import PyBlockStatement +from codegen.sdk.python.statements.break_statement import PyBreakStatement +from codegen.sdk.python.statements.catch_statement import PyCatchStatement +from codegen.sdk.python.statements.comment import PyComment +from codegen.sdk.python.statements.comment import PyCommentType +from codegen.sdk.python.statements.for_loop_statement import PyForLoopStatement +from codegen.sdk.python.statements.if_block_statement import PyIfBlockStatement +from codegen.sdk.python.statements.import_statement import PyImportStatement +from codegen.sdk.python.statements.match_case import PyMatchCase +from codegen.sdk.python.statements.match_statement import PyMatchStatement +from codegen.sdk.python.statements.pass_statement import PyPassStatement +from codegen.sdk.python.statements.try_catch_statement import PyTryCatchStatement +from codegen.sdk.python.statements.while_statement import PyWhileStatement +from codegen.sdk.python.statements.with_statement import WithStatement +from codegen.sdk.python.symbol import PySymbol +from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup +from codegen.sdk.typescript.assignment import TSAssignment +from codegen.sdk.typescript.class_definition import TSClass +from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock +from codegen.sdk.typescript.detached_symbols.decorator import TSDecorator +from codegen.sdk.typescript.detached_symbols.jsx.element import JSXElement +from codegen.sdk.typescript.detached_symbols.jsx.expression import JSXExpression +from codegen.sdk.typescript.detached_symbols.jsx.prop import JSXProp +from codegen.sdk.typescript.detached_symbols.parameter import TSParameter +from codegen.sdk.typescript.enum_definition import TSEnum +from codegen.sdk.typescript.export import TSExport +from codegen.sdk.typescript.expressions.array_type import TSArrayType +from codegen.sdk.typescript.expressions.chained_attribute import TSChainedAttribute +from codegen.sdk.typescript.expressions.conditional_type import TSConditionalType +from codegen.sdk.typescript.expressions.expression_type import TSExpressionType +from codegen.sdk.typescript.expressions.function_type import TSFunctionType +from codegen.sdk.typescript.expressions.generic_type import TSGenericType +from codegen.sdk.typescript.expressions.lookup_type import TSLookupType +from codegen.sdk.typescript.expressions.named_type import TSNamedType +from codegen.sdk.typescript.expressions.object_type import TSObjectType +from codegen.sdk.typescript.expressions.query_type import TSQueryType +from codegen.sdk.typescript.expressions.readonly_type import TSReadonlyType +from codegen.sdk.typescript.expressions.string import TSString +from codegen.sdk.typescript.expressions.ternary_expression import TSTernaryExpression +from codegen.sdk.typescript.expressions.undefined_type import TSUndefinedType +from codegen.sdk.typescript.expressions.union_type import TSUnionType +from codegen.sdk.typescript.file import TSFile +from codegen.sdk.typescript.function import TSFunction +from codegen.sdk.typescript.import_resolution import TSImport +from codegen.sdk.typescript.interface import TSInterface +from codegen.sdk.typescript.interfaces.has_block import TSHasBlock +from codegen.sdk.typescript.namespace import TSNamespace +from codegen.sdk.typescript.placeholder.placeholder_return_type import TSReturnTypePlaceholder +from codegen.sdk.typescript.statements.assignment_statement import TSAssignmentStatement +from codegen.sdk.typescript.statements.attribute import TSAttribute +from codegen.sdk.typescript.statements.block_statement import TSBlockStatement +from codegen.sdk.typescript.statements.catch_statement import TSCatchStatement +from codegen.sdk.typescript.statements.comment import TSComment +from codegen.sdk.typescript.statements.comment import TSCommentType +from codegen.sdk.typescript.statements.for_loop_statement import TSForLoopStatement +from codegen.sdk.typescript.statements.if_block_statement import TSIfBlockStatement +from codegen.sdk.typescript.statements.import_statement import TSImportStatement +from codegen.sdk.typescript.statements.labeled_statement import TSLabeledStatement +from codegen.sdk.typescript.statements.switch_case import TSSwitchCase +from codegen.sdk.typescript.statements.switch_statement import TSSwitchStatement +from codegen.sdk.typescript.statements.try_catch_statement import TSTryCatchStatement +from codegen.sdk.typescript.statements.while_statement import TSWhileStatement +from codegen.sdk.typescript.symbol import TSSymbol +from codegen.sdk.typescript.symbol_groups.comment_group import TSCommentGroup +from codegen.sdk.typescript.symbol_groups.dict import TSDict +from codegen.sdk.typescript.symbol_groups.dict import TSPair +from codegen.sdk.typescript.ts_config import TSConfig +from codegen.sdk.typescript.type_alias import TSTypeAlias +""" diff --git a/src/codegen/utils/compilation/string_to_code.py b/src/codegen/utils/compilation/string_to_code.py new file mode 100644 index 000000000..3cb585c89 --- /dev/null +++ b/src/codegen/utils/compilation/string_to_code.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import linecache +import logging +import sys +import traceback +from collections.abc import Callable +from typing import Any + +from codegen.utils.compilation.codeblock_validation import check_for_dangerous_operations +from codegen.utils.compilation.exception_utils import get_local_frame, get_offset_traceback +from codegen.utils.compilation.function_compilation import safe_compile_function_string +from codegen.utils.compilation.function_construction import create_function_str_from_codeblock, get_imports_string +from codegen.utils.exceptions.control_flow import StopCodemodException + +logger = logging.getLogger(__name__) + + +def create_execute_function_from_codeblock(codeblock: str, custom_scope: dict | None = None, func_name: str = "execute") -> Callable: + """Convert a user code string into a Callable that takes in a Codebase. + + Steps: + 1. Check for any dangerous operations in the codeblock. Will raise DangerousUserCodeException if any dangerous operations are found. + 2. Create a function string from the codeblock. Ex: "def execute(codebase: Codebase): ..." + 3. Compile the function string into a Callable that takes in a Codebase. Will raise InvalidUserCodeException if there are any code errors (ex: IndentationErrors) + 4. Wrap the function in another function (that also takes in a Codebase) that handles calling the function and safely handling any exceptions occur during execution. + + Args: + codeblock (str): The user code to construct the Callable with (usually CodemodVersionModel.source) + custom_scope (dict | None, optional): Custom scope to be used during compilation. Defaults to None. + func_name (str, optional): Name of the function to be created. Defaults to "execute". + + Returns: + Callable: def (codebase: Codebase) -> any | dict + + Raises: + UnsafeUserCodeException: If the user's code contains dangerous operations. + InvalidUserCodeException: If there are syntax errors in the provided code. + """ + # =====[ Set up custom scope ]===== + custom_scope = custom_scope or {} + logger.info(f"create_execute_function custom_scope: {custom_scope.keys()}") + + # =====[ Check for dangerous operations in the codeblock ]===== + check_for_dangerous_operations(codeblock) + # =====[ Create function string from codeblock ]===== + func_str = create_function_str_from_codeblock(codeblock, func_name) + # =====[ Compile the function string into a function ]===== + func = safe_compile_function_string(custom_scope=custom_scope, func_name=func_name, func_str=func_str) + + # =====[ Compute line offset of func_str ]===== + # This is to generate the a traceback with the correct line window + len_imports = len(get_imports_string().split("\n")) + len_func_str = 1 + line_offset = len_imports + len_func_str + + # =====[ Create closure function to enclose outer scope variables]===== + def closure_func() -> Callable[[Any], None]: + """Wrap user code in a closure to capture the outer scope variables and format errors.""" + _func_str = func_str + _line_offset = line_offset + + # Wrap the func for better tracing + def wrapped_func(*args, **kwargs): + """Wraps the user code to capture and format exceptions + grab locals""" + try: + linecache.cache[""] = (len(_func_str), None, _func_str.splitlines(True), "") + func(*args, **kwargs) + + # =====[ Grab locals during `StopCodemodException` ]===== + except StopCodemodException as e: + logger.info(f"Stopping codemod due to {e.__class__.__name__}: {e}") + raise e + + except Exception as e: + # =====[ Get offset, filtered traceback message ]===== + tb_lines = traceback.format_exception(type(e), e, e.__traceback__) + error_message = get_offset_traceback(tb_lines, _line_offset, filenameFilter="") + + # =====[ Find frame in user's code ]===== + exc_type, exc_value, exc_traceback = sys.exc_info() + frame = get_local_frame(exc_type, exc_value, exc_traceback) + # TODO: handle frame is None + line_num = frame.f_lineno + + # =====[ Get context lines ]===== + context_start = max(0, line_num - 3) + context_end = min(len(func_str.split("\n")), line_num + 2) + context_lines = func_str.split("\n")[context_start:context_end] + + # =====[ Format error message with context ]===== + error_lines = [] + for i, line in enumerate(context_lines, start=context_start + 1): + marker = ">" if i == line_num else " " + error_lines.append(f"{marker} {i - _line_offset}: {line.rstrip()}") + error_context = "\n".join(error_lines) + + # =====[ Format error message ]===== + error_message = ( + error_message + + f""" + +Code context: +{error_context} +""" + ) + raise RuntimeError(error_message) from e + + return wrapped_func + + return closure_func() diff --git a/src/codegen/utils/performance/stopwatch_utils.py b/src/codegen/utils/performance/stopwatch_utils.py index ac4b76e9f..87f2611ba 100644 --- a/src/codegen/utils/performance/stopwatch_utils.py +++ b/src/codegen/utils/performance/stopwatch_utils.py @@ -5,7 +5,7 @@ import sentry_sdk -from codegen.utils.time_utils import humanize_duration +from codegen.utils.performance.time_utils import humanize_duration logger = logging.getLogger(__name__) diff --git a/src/codegen/utils/time_utils.py b/src/codegen/utils/performance/time_utils.py similarity index 100% rename from src/codegen/utils/time_utils.py rename to src/codegen/utils/performance/time_utils.py diff --git a/src/codegen/utils/csv_utils.py b/src/codegen/utils/string/csv_utils.py similarity index 100% rename from src/codegen/utils/csv_utils.py rename to src/codegen/utils/string/csv_utils.py