Skip to content

Commit 22cdad9

Browse files
authored
Merge pull request packit#84 from nforro/workflows
BeeAI: implement a workflow for rebase and backport agents
2 parents 889d9f4 + d562e3c commit 22cdad9

File tree

9 files changed

+439
-340
lines changed

9 files changed

+439
-340
lines changed

beeai/Containerfile.mcp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ RUN useradd -m -G wheel mcp
2323
# although it is locally mounted through a volume
2424
COPY mcp_server/ /home/mcp/mcp_server/
2525
RUN chgrp -R root /home/mcp && chmod -R g+rwX /home/mcp
26+
RUN mkdir /git-repos && chmod -R o+rwX /git-repos
2627

2728
USER mcp
2829
WORKDIR /home/mcp

beeai/agents/backport_agent.py

Lines changed: 119 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import asyncio
22
import logging
33
import os
4-
from shutil import rmtree
5-
from pathlib import Path
64
import subprocess
75
import sys
86
import traceback
9-
from typing import Optional
7+
from pathlib import Path
108

119
from pydantic import BaseModel, Field
1210

@@ -22,45 +20,41 @@
2220
from beeai_framework.tools import Tool
2321
from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool
2422
from beeai_framework.tools.think import ThinkTool
23+
from beeai_framework.workflows import Workflow
2524

25+
import tasks
26+
from constants import COMMIT_PREFIX
27+
from observability import setup_observability
28+
from tools.commands import RunShellCommandTool
2629
from tools.specfile import AddChangelogEntryTool, BumpReleaseTool
2730
from tools.text import CreateTool, InsertTool, StrReplaceTool, ViewTool
2831
from tools.wicked_git import GitLogSearchTool, GitPatchCreationTool
29-
from constants import COMMIT_PREFIX, BRANCH_PREFIX
30-
from observability import setup_observability
31-
from tools.commands import RunShellCommandTool
3232
from triage_agent import BackportData, ErrorData
33-
from utils import get_agent_execution_config, mcp_tools, redis_client, get_git_finalization_steps
33+
from utils import check_subprocess, get_agent_execution_config, mcp_tools, redis_client
3434

3535
logger = logging.getLogger(__name__)
3636

3737

3838
class InputSchema(BaseModel):
39+
local_clone: Path = Field(description="Path to the local clone of forked dist-git repository")
40+
unpacked_sources: Path = Field(description="Path to the unpacked (using `centpkg prep`) sources")
3941
package: str = Field(description="Package to update")
42+
dist_git_branch: str = Field(description="Git branch in dist-git to be updated")
4043
upstream_fix: str = Field(description="Link to an upstream fix for the issue")
4144
jira_issue: str = Field(description="Jira issue to reference as resolved")
4245
cve_id: str = Field(default="", description="CVE ID if the jira issue is a CVE")
43-
dist_git_branch: str = Field(description="Git branch in dist-git to be updated")
44-
git_repo_basepath: str = Field(
45-
description="Base path for cloned git repos",
46-
default=os.getenv("GIT_REPO_BASEPATH"),
47-
)
48-
unpacked_sources: str = Field(
49-
description="Path to the unpacked (using `centpkg prep`) sources",
50-
default="",
51-
)
5246

5347

5448
class OutputSchema(BaseModel):
5549
success: bool = Field(description="Whether the backport was successfully completed")
5650
status: str = Field(description="Backport status")
57-
mr_url: Optional[str] = Field(description="URL to the opened merge request")
58-
error: Optional[str] = Field(description="Specific details about an error")
51+
mr_url: str | None = Field(description="URL to the opened merge request")
52+
error: str | None = Field(description="Specific details about an error")
5953

6054

6155
def render_prompt(input: InputSchema) -> str:
6256
template = (
63-
'Work inside the repository cloned at "{{ git_repo_basepath }}/{{ package }}"\n'
57+
'Work inside the repository cloned in "{{ local_clone }}", it is your current working directory\n'
6458
"Use the `git_log_search` tool to check if the jira issue ({{ jira_issue }}) or CVE ({{ cve_id }}) is already resolved.\n"
6559
"If the issue or the cve are already resolved, exit the backporting process with success=True and status=\"Backport already applied\"\n"
6660
"Download the upstream fix from {{ upstream_fix }}\n"
@@ -74,64 +68,8 @@ def render_prompt(input: InputSchema) -> str:
7468
"Delete all *.rej files\n"
7569
"DO **NOT** RUN COMMAND `git am --continue`\n"
7670
"Once you resolve all conflicts, use tool git_patch_create to create a patch file\n"
77-
"{{ backport_git_steps }}"
7871
)
79-
80-
# Define template function that can be called from the template
81-
def backport_git_steps(data):
82-
input_data = InputSchema.model_validate(data)
83-
return get_git_finalization_steps(
84-
package=input_data.package,
85-
jira_issue=input_data.jira_issue,
86-
commit_title=f"{COMMIT_PREFIX} backport {input_data.jira_issue}",
87-
files_to_commit=f"*.spec and {input_data.jira_issue}.patch",
88-
branch_name=f"{BRANCH_PREFIX}-{input_data.jira_issue}",
89-
dist_git_branch=input_data.dist_git_branch,
90-
)
91-
92-
return PromptTemplate(
93-
PromptTemplateInput(schema=InputSchema, template=template, functions={"backport_git_steps": backport_git_steps})
94-
).render(input)
95-
96-
97-
def prepare_package(
98-
package: str, jira_issue: str, dist_git_branch: str, input_schema: InputSchema
99-
) -> tuple[Path, Path]:
100-
"""
101-
Prepare the package for backporting by cloning the dist-git repository, switching to the appropriate branch,
102-
and downloading the sources.
103-
Returns the path to the unpacked sources.
104-
"""
105-
git_repo = Path(input_schema.git_repo_basepath)
106-
git_repo.mkdir(parents=True, exist_ok=True)
107-
subprocess.check_call(
108-
[
109-
"centpkg",
110-
"clone",
111-
"--anonymous",
112-
"--branch",
113-
dist_git_branch,
114-
package,
115-
],
116-
cwd=git_repo,
117-
)
118-
local_clone = git_repo / package
119-
subprocess.check_call(
120-
[
121-
"git",
122-
"switch",
123-
"-c",
124-
f"automated-package-update-{jira_issue}",
125-
dist_git_branch,
126-
],
127-
cwd=local_clone,
128-
)
129-
subprocess.check_call(["centpkg", "sources"], cwd=local_clone)
130-
subprocess.check_call(["centpkg", "prep"], cwd=local_clone)
131-
unpacked_sources = list(local_clone.glob(f"*-build/*{package}*"))
132-
if len(unpacked_sources) != 1:
133-
raise ValueError(f"Expected exactly one unpacked source, got {unpacked_sources}")
134-
return unpacked_sources[0], local_clone
72+
return PromptTemplate(PromptTemplateInput(schema=InputSchema, template=template)).render(input)
13573

13674

13775
async def main() -> None:
@@ -141,7 +79,7 @@ async def main() -> None:
14179
cve_id = os.getenv("CVE_ID", "")
14280

14381
async with mcp_tools(os.getenv("MCP_GATEWAY_URL")) as gateway_tools:
144-
agent = RequirementAgent(
82+
backport_agent = RequirementAgent(
14583
llm=ChatModel.from_name(os.getenv("CHAT_MODEL")),
14684
tools=[
14785
ThinkTool(),
@@ -155,11 +93,6 @@ async def main() -> None:
15593
GitLogSearchTool(),
15694
BumpReleaseTool(),
15795
AddChangelogEntryTool(),
158-
]
159-
+ [
160-
t
161-
for t in gateway_tools
162-
if t.name in ("fork_repository", "open_merge_request", "push_to_remote_repository")
16396
],
16497
memory=UnconstrainedMemory(),
16598
requirements=[
@@ -182,41 +115,110 @@ async def main() -> None:
182115
],
183116
)
184117

185-
dry_run = os.getenv("DRY_RUN", "False").lower() == "true"
118+
class State(BaseModel):
119+
jira_issue: str
120+
package: str
121+
dist_git_branch: str
122+
upstream_fix: str
123+
cve_id: str
124+
local_clone: Path | None = Field(default=None)
125+
update_branch: str | None = Field(default=None)
126+
fork_url: str | None = Field(default=None)
127+
unpacked_sources: Path | None = Field(default=None)
128+
backport_result: OutputSchema | None = Field(default=None)
129+
merge_request_url: str | None = Field(default=None)
130+
131+
workflow = Workflow(State)
186132

187-
async def run(input):
188-
response = await agent.run(
189-
prompt=render_prompt(input),
190-
expected_output=OutputSchema,
191-
execution=get_agent_execution_config(),
133+
async def fork_and_prepare_dist_git(state):
134+
state.local_clone, state.update_branch, state.fork_url = await tasks.fork_and_prepare_dist_git(
135+
jira_issue=state.jira_issue,
136+
package=state.package,
137+
dist_git_branch=state.dist_git_branch,
138+
available_tools=gateway_tools,
192139
)
193-
return OutputSchema.model_validate_json(response.answer.text)
140+
await check_subprocess(["centpkg", "sources"], cwd=state.local_clone)
141+
await check_subprocess(["centpkg", "prep"], cwd=state.local_clone)
142+
unpacked_sources = list(state.local_clone.glob(f"*-build/*{state.package}*"))
143+
if len(unpacked_sources) != 1:
144+
raise ValueError(f"Expected exactly one unpacked source, got {unpacked_sources}")
145+
[state.unpacked_sources] = unpacked_sources
146+
return "run_backport_agent"
147+
148+
async def run_backport_agent(state):
149+
cwd = Path.cwd()
150+
try:
151+
# make things easier for the LLM
152+
os.chdir(state.local_clone)
153+
response = await backport_agent.run(
154+
prompt=render_prompt(
155+
InputSchema(
156+
local_clone=state.local_clone,
157+
unpacked_sources=state.unpacked_sources,
158+
package=state.package,
159+
dist_git_branch=state.dist_git_branch,
160+
upstream_fix=state.upstream_fix,
161+
jira_issue=state.jira_issue,
162+
cve_id=state.cve_id,
163+
),
164+
),
165+
expected_output=OutputSchema,
166+
execution=get_agent_execution_config(),
167+
)
168+
state.backport_result = OutputSchema.model_validate_json(response.answer.text)
169+
finally:
170+
os.chdir(cwd)
171+
if state.backport_result.success:
172+
return "commit_push_and_open_mr"
173+
else:
174+
return Workflow.END
175+
176+
async def commit_push_and_open_mr(state):
177+
state.merge_request_url = await tasks.commit_push_and_open_mr(
178+
local_clone=state.local_clone,
179+
files_to_commit=["*.spec", f"{state.jira_issue}.patch"],
180+
commit_message=f"{COMMIT_PREFIX} backport {state.jira_issue}",
181+
fork_url=state.fork_url,
182+
dist_git_branch=state.dist_git_branch,
183+
update_branch=state.update_branch,
184+
mr_title="{COMMIT_PREFIX} backport {state.jira_issue}",
185+
mr_description="TODO",
186+
available_tools=gateway_tools,
187+
commit_only=os.getenv("DRY_RUN", "False").lower() == "true",
188+
)
189+
return Workflow.END
190+
191+
workflow.add_step("fork_and_prepare_dist_git", fork_and_prepare_dist_git)
192+
workflow.add_step("run_backport_agent", run_backport_agent)
193+
workflow.add_step("commit_push_and_open_mr", commit_push_and_open_mr)
194+
195+
async def run_workflow(package, dist_git_branch, upstream_fix, jira_issue, cve_id):
196+
response = await workflow.run(
197+
State(
198+
package=package,
199+
dist_git_branch=dist_git_branch,
200+
upstream_fix=upstream_fix,
201+
jira_issue=jira_issue,
202+
cve_id=cve_id,
203+
),
204+
)
205+
return response.state
194206

195207
if (
196208
(package := os.getenv("PACKAGE", None))
209+
and (branch := os.getenv("BRANCH", None))
197210
and (upstream_fix := os.getenv("UPSTREAM_FIX", None))
198211
and (jira_issue := os.getenv("JIRA_ISSUE", None))
199-
and (branch := os.getenv("BRANCH", None))
200212
):
201213
logger.info("Running in direct mode with environment variables")
202-
input = InputSchema(
214+
state = await run_workflow(
203215
package=package,
216+
dist_git_branch=branch,
204217
upstream_fix=upstream_fix,
205218
jira_issue=jira_issue,
206-
dist_git_branch=branch,
207-
cve_id=cve_id,
219+
cve_id=os.getenv("CVE_ID", ""),
208220
)
209-
unpacked_sources, local_clone = prepare_package(package, jira_issue, branch, input)
210-
input.unpacked_sources = str(unpacked_sources)
211-
try:
212-
output = await run(input)
213-
finally:
214-
if not dry_run:
215-
logger.info(f"Removing {local_clone}")
216-
rmtree(local_clone)
217-
else:
218-
logger.info(f"DRY RUN: Not removing {local_clone}")
219-
logger.info(f"Direct run completed: {output.model_dump_json(indent=4)}")
221+
logger.info(f"Direct run completed: {state.backport_result.model_dump_json(indent=4)}")
220222
return
221223

222224
class Task(BaseModel):
@@ -245,18 +247,6 @@ class Task(BaseModel):
245247
f"JIRA: {backport_data.jira_issue}, attempt: {task.attempts + 1}"
246248
)
247249

248-
input = InputSchema(
249-
package=backport_data.package,
250-
upstream_fix=backport_data.patch_url,
251-
jira_issue=backport_data.jira_issue,
252-
dist_git_branch=backport_data.branch,
253-
cve_id=backport_data.cve_id,
254-
)
255-
unpacked_sources, local_clone = prepare_package(
256-
backport_data.package, backport_data.jira_issue, backport_data.branch, input
257-
)
258-
input.unpacked_sources = str(unpacked_sources)
259-
260250
async def retry(task, error):
261251
task.attempts += 1
262252
if task.attempts < max_retries:
@@ -274,23 +264,29 @@ async def retry(task, error):
274264

275265
try:
276266
logger.info(f"Starting backport processing for {backport_data.jira_issue}")
277-
output = await run(input)
267+
state = await run_workflow(
268+
package=backport_data.package,
269+
dist_git_branch=backport_data.branch,
270+
upstream_fix=backport_data.patch_url,
271+
jira_issue=backport_data.jira_issue,
272+
cve_id=backport_data.cve_id,
273+
)
278274
logger.info(
279-
f"Backport processing completed for {backport_data.jira_issue}, " f"success: {output.success}"
275+
f"Backport processing completed for {backport_data.jira_issue}, " f"success: {state.backport_result.success}"
280276
)
281277
except Exception as e:
282278
error = "".join(traceback.format_exception(e))
283279
logger.error(f"Exception during backport processing for {backport_data.jira_issue}: {error}")
284-
await retry(task, ErrorData(details=error, jira_issue=input.jira_issue).model_dump_json())
280+
await retry(task, ErrorData(details=error, jira_issue=backport_data.jira_issue).model_dump_json())
285281
rmtree(local_clone)
286282
else:
287283
rmtree(local_clone)
288-
if output.success:
284+
if state.backport_data.success:
289285
logger.info(f"Backport successful for {backport_data.jira_issue}, " f"adding to completed list")
290-
await redis.lpush("completed_backport_list", output.model_dump_json())
286+
await redis.lpush("completed_backport_list", state.backport_data.model_dump_json())
291287
else:
292-
logger.warning(f"Backport failed for {backport_data.jira_issue}: {output.error}")
293-
await retry(task, output.error)
288+
logger.warning(f"Backport failed for {backport_data.jira_issue}: {state.backport_data.error}")
289+
await retry(task, state.backport_data.error)
294290

295291

296292
if __name__ == "__main__":

0 commit comments

Comments
 (0)