Skip to content

Commit 511c633

Browse files
committed
update tests and restore sandbox tests
1 parent cb26533 commit 511c633

File tree

12 files changed

+551
-243
lines changed

12 files changed

+551
-243
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from collections.abc import Generator
2+
from unittest.mock import Mock
3+
4+
import pytest
5+
6+
from graph_sitter.git.clients.git_repo_client import GitRepoClient
7+
from graph_sitter.git.repo_operator.repo_operator import RepoOperator
8+
from graph_sitter.git.schemas.enums import SetupOption
9+
from graph_sitter.git.schemas.repo_config import RepoConfig
10+
from graph_sitter.runner.clients.codebase_client import CodebaseClient
11+
from graph_sitter.shared.enums.programming_language import ProgrammingLanguage
12+
from graph_sitter.shared.network.port import get_free_port
13+
14+
15+
@pytest.fixture()
16+
def repo_config(tmpdir) -> Generator[RepoConfig, None, None]:
17+
yield RepoConfig(
18+
name="Kevin-s-Adventure-Game",
19+
full_name="codegen-sh/Kevin-s-Adventure-Game",
20+
language=ProgrammingLanguage.PYTHON,
21+
base_dir=str(tmpdir),
22+
)
23+
24+
25+
@pytest.fixture
26+
def op(repo_config: RepoConfig) -> Generator[RepoOperator, None, None]:
27+
yield RepoOperator(repo_config=repo_config, setup_option=SetupOption.PULL_OR_CLONE)
28+
29+
30+
@pytest.fixture
31+
def git_repo_client(op: RepoOperator, repo_config: RepoConfig) -> Generator[GitRepoClient, None, None]:
32+
yield GitRepoClient(repo_config=repo_config, access_token=op.access_token)
33+
34+
35+
@pytest.fixture
36+
def codebase_client(repo_config: RepoConfig) -> Generator[CodebaseClient, None, None]:
37+
sb_client = CodebaseClient(repo_config=repo_config, port=get_free_port())
38+
sb_client.runner = Mock()
39+
yield sb_client
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import uuid
2+
from http import HTTPStatus
3+
4+
import pytest
5+
6+
from graph_sitter.git.clients.git_repo_client import GitRepoClient
7+
from graph_sitter.git.repo_operator.repo_operator import RepoOperator
8+
from graph_sitter.runner.clients.codebase_client import CodebaseClient
9+
from graph_sitter.runner.models.apis import BRANCH_ENDPOINT, CreateBranchRequest, CreateBranchResponse
10+
from graph_sitter.runner.models.codemod import BranchConfig, Codemod, GroupingConfig
11+
12+
13+
@pytest.mark.asyncio
14+
@pytest.mark.timeout(60)
15+
async def test_create_branch(codebase_client: CodebaseClient, git_repo_client: GitRepoClient, op: RepoOperator):
16+
# set-up
17+
codemod_source = """
18+
for file in codebase.files:
19+
new_content = "🌈" + "\\n" + file.content
20+
file.edit(new_content)
21+
break
22+
"""
23+
test_branch_name = f"codegen-test-create-branch-{uuid.uuid1()}"
24+
request = CreateBranchRequest(
25+
codemod=Codemod(user_code=codemod_source),
26+
commit_msg="Create branch test",
27+
grouping_config=GroupingConfig(),
28+
branch_config=BranchConfig(branch_name=test_branch_name),
29+
)
30+
31+
# execute
32+
response = codebase_client.post(endpoint=BRANCH_ENDPOINT, data=request.model_dump())
33+
assert response.status_code == HTTPStatus.OK
34+
35+
# verify
36+
result = CreateBranchResponse.model_validate(response.json())
37+
assert len(result.results) == 1
38+
assert result.results[0].is_complete
39+
assert result.results[0].error is None
40+
assert result.results[0].logs == ""
41+
assert result.results[0].observation is not None
42+
43+
# verify changed files
44+
patch = result.results[0].observation
45+
lines = patch.split("\n")
46+
added_lines = [line[1:] for line in lines if line.startswith("+") and len(line) > 1]
47+
assert "🌈" in added_lines
48+
49+
# verify returned branch
50+
assert len(result.branches) == 1
51+
branch = result.branches[0]
52+
assert branch.base_branch == "main"
53+
assert branch.head_ref == test_branch_name
54+
55+
# verify remote branch
56+
remote_branch = git_repo_client.repo.get_branch(test_branch_name)
57+
assert remote_branch is not None
58+
assert remote_branch.name == test_branch_name
59+
assert remote_branch.commit.commit.message == "[Codegen] Create branch test"
60+
61+
# clean-up
62+
remote = op.git_cli.remote(name="origin")
63+
remote.push([f":refs/heads/{test_branch_name}"]) # The colon prefix means delete
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import uuid
2+
from http import HTTPStatus
3+
4+
import pytest
5+
6+
from graph_sitter.codebase.flagging.groupers.enums import GroupBy
7+
from graph_sitter.git.clients.git_repo_client import GitRepoClient
8+
from graph_sitter.git.repo_operator.repo_operator import RepoOperator
9+
from graph_sitter.runner.clients.codebase_client import CodebaseClient
10+
from graph_sitter.runner.models.apis import BRANCH_ENDPOINT, CreateBranchRequest, CreateBranchResponse
11+
from graph_sitter.runner.models.codemod import BranchConfig, Codemod, GroupingConfig
12+
13+
14+
@pytest.mark.timeout(120)
15+
@pytest.mark.parametrize("group_by", [GroupBy.INSTANCE, GroupBy.FILE])
16+
def test_create_branch_with_grouping(codebase_client: CodebaseClient, git_repo_client: GitRepoClient, op: RepoOperator, group_by: GroupBy):
17+
codemod_source = """
18+
for file in codebase.files[:5]:
19+
flag = codebase.flag_instance(file)
20+
if codebase.should_fix(flag):
21+
new_content = "🌈" + "\\n" + file.content
22+
file.edit(new_content)
23+
"""
24+
commit_msg = "Create branch with grouping test"
25+
test_branch_name = f"codegen-{uuid.uuid1()}"
26+
request = CreateBranchRequest(
27+
codemod=Codemod(user_code=codemod_source),
28+
commit_msg=commit_msg,
29+
grouping_config=GroupingConfig(group_by=group_by),
30+
branch_config=BranchConfig(branch_name=test_branch_name),
31+
)
32+
33+
# execute
34+
response = codebase_client.post(endpoint=BRANCH_ENDPOINT, data=request.model_dump())
35+
assert response.status_code == HTTPStatus.OK
36+
37+
# verify
38+
result = CreateBranchResponse.model_validate(response.json())
39+
assert len(result.results) == 5
40+
assert len(result.branches) == 5
41+
42+
for i, branch in enumerate(result.branches):
43+
actual_branch_suffix = "-".join(branch.head_ref.split("-")[-2:])
44+
expected_branch_suffix = f"group-{i}"
45+
assert expected_branch_suffix == actual_branch_suffix
46+
47+
remote_branch = git_repo_client.repo.get_branch(branch.head_ref)
48+
assert remote_branch is not None
49+
assert remote_branch.name == branch.head_ref
50+
assert remote_branch.commit.commit.message == f"[Codegen] {commit_msg}"
51+
assert remote_branch.commit.commit.author.name == "codegen-sh[bot]"
52+
53+
comparison = git_repo_client.repo.compare(base=branch.base_branch, head=branch.head_ref)
54+
assert "+🌈" in comparison.files[0].patch
55+
56+
# clean-up
57+
remote = op.git_cli.remote(name="origin")
58+
remote.push([f":refs/heads/{branch.head_ref}"])
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from collections.abc import Generator
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from graph_sitter.codebase.config import ProjectConfig
7+
from graph_sitter.core.codebase import Codebase
8+
from graph_sitter.git.repo_operator.repo_operator import RepoOperator
9+
from graph_sitter.runner.sandbox.executor import SandboxExecutor
10+
from graph_sitter.runner.sandbox.runner import SandboxRunner
11+
from graph_sitter.shared.enums.programming_language import ProgrammingLanguage
12+
13+
14+
@pytest.fixture
15+
def codebase(tmpdir) -> Codebase:
16+
op = RepoOperator.create_from_files(repo_path=f"{tmpdir}/test-repo", files={"test.py": "a = 1"}, bot_commit=True)
17+
projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON)]
18+
codebase = Codebase(projects=projects)
19+
return codebase
20+
21+
22+
@pytest.fixture
23+
def executor(codebase: Codebase) -> Generator[SandboxExecutor]:
24+
yield SandboxExecutor(codebase)
25+
26+
27+
@pytest.fixture
28+
def runner(codebase: Codebase, tmpdir):
29+
with patch("graph_sitter.runner.sandbox.runner.RepoOperator") as mock_op:
30+
with patch.object(SandboxRunner, "_build_graph") as mock_init_codebase:
31+
mock_init_codebase.return_value = codebase
32+
mock_op.return_value = codebase.op
33+
34+
yield SandboxRunner(repo_config=codebase.op.repo_config)

0 commit comments

Comments
 (0)