Skip to content

Commit 0cafa43

Browse files
committed
Implement a workflow for the rebase agent
Signed-off-by: Nikola Forró <[email protected]>
1 parent 889d9f4 commit 0cafa43

File tree

6 files changed

+317
-197
lines changed

6 files changed

+317
-197
lines changed

beeai/agents/rebase_agent.py

Lines changed: 141 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import sys
55
import traceback
6-
from typing import Optional
6+
from pathlib import Path
77

88
from pydantic import BaseModel, Field
99

@@ -19,123 +19,82 @@
1919
from beeai_framework.tools import Tool
2020
from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool
2121
from beeai_framework.tools.think import ThinkTool
22+
from beeai_framework.workflows import Workflow
2223

23-
from constants import COMMIT_PREFIX, BRANCH_PREFIX
24+
import tasks
25+
from constants import COMMIT_PREFIX
2426
from observability import setup_observability
2527
from tools.commands import RunShellCommandTool
28+
from tools.specfile import AddChangelogEntryTool
29+
from tools.text import CreateTool, InsertTool, StrReplaceTool, ViewTool
2630
from triage_agent import RebaseData, ErrorData
27-
from utils import get_agent_execution_config, mcp_tools, redis_client, get_git_finalization_steps
31+
from utils import get_agent_execution_config, mcp_tools, redis_client, run_tool
2832

2933
logger = logging.getLogger(__name__)
3034

3135

3236
class InputSchema(BaseModel):
37+
local_clone: Path = Field(description="Path to the local clone of forked dist-git repository")
3338
package: str = Field(description="Package to update")
39+
dist_git_branch: str = Field(description="dist-git branch to update")
3440
version: str = Field(description="Version to update to")
3541
jira_issue: str = Field(description="Jira issue to reference as resolved")
36-
dist_git_branch: str = Field(description="Git branch in dist-git to be updated")
37-
gitlab_user: str = Field(
38-
description="Name of the GitLab user",
39-
default=os.getenv("GITLAB_USER", "rhel-packaging-agent"),
40-
)
41-
git_url: str = Field(
42-
description="URL of the git repository",
43-
default="https://gitlab.com/redhat/centos-stream/rpms",
44-
)
45-
git_repo_basepath: str = Field(
46-
description="Base path for cloned git repos",
47-
default=os.getenv("GIT_REPO_BASEPATH"),
48-
)
4942

5043

5144
class OutputSchema(BaseModel):
5245
success: bool = Field(description="Whether the rebase was successfully completed")
5346
status: str = Field(description="Rebase status")
54-
mr_url: Optional[str] = Field(description="URL to the opened merge request")
55-
error: Optional[str] = Field(description="Specific details about an error")
47+
mr_url: str | None = Field(description="URL to the opened merge request")
48+
error: str | None = Field(description="Specific details about an error")
5649

5750

5851
def render_prompt(input: InputSchema) -> str:
5952
template = """
60-
You are an AI Agent tasked to rebase a CentOS package to a newer version following the exact workflow.
53+
You are an AI Agent tasked to rebase a package to a newer version following the exact workflow.
6154
6255
A couple of rules that you must follow and useful information for you:
63-
* All packages are in separate Git repositories under the Gitlab project {{ git_url }}
64-
* You can find the package at {{ git_url }}/{{ package }}
65-
* Use {{ gitlab_user }} as the GitLab user.
66-
* Work only in a temporary directory that you can create with the mktemp tool.
67-
* You can find packaging guidelines at https://docs.fedoraproject.org/en-US/packaging-guidelines/
56+
* You can find packaging guidelines at https://docs.fedoraproject.org/en-US/packaging-guidelines/.
6857
* You can find the RPM packaging guide at https://rpm-packaging-guide.github.io/.
69-
* Do not run the `centpkg new-sources` command for now (testing purposes), just write down the commands you would run.
70-
71-
IMPORTANT GUIDELINES:
72-
- **Tool Usage**: You have run_shell_command tool available - use it directly!
73-
- **Command Execution Rules**:
74-
- Use run_shell_command tool for ALL command execution
75-
- If a command shows "no output" or empty STDOUT, that is a VALID result - do not retry
76-
- Commands that succeed with no output are normal - report success
77-
- **Git Configuration**: Always configure git user name and email before any git operations
58+
* IMPORTANT: Do not run the `centpkg new-sources` command for now (testing purposes), just write down
59+
the commands you would run.
7860
7961
Follow exactly these steps:
8062
81-
1. Find the location of the {{ package }} package at {{ git_url }}. Always use the {{ dist_git_branch }} branch.
63+
1. You will find the cloned dist-git repository of the {{ package }} package in {{ local_clone }}.
64+
It is your current working directory, do not `cd` anywhere else.
8265
83-
2. Check if the {{ package }} was not already updated to version {{ version }}. That means comparing
84-
the current version and provided version.
85-
* The current version of the package can be found in the 'Version' field of the RPM .spec file.
66+
2. Check if the {{ package }} was not already updated to version {{ version }}. That means comparing
67+
the current version with the provided version.
68+
* The current version of the package can be found in the 'Version' field of the spec file.
8669
* If there is nothing to update, print a message and exit. Otherwise follow the instructions below.
87-
* Do not clone any repository for detecting the version in .spec file.
88-
89-
3. Create a local Git repository by following these steps:
90-
* Create a fork of the {{ package }} package using the `fork_repository` tool.
91-
* Clone the fork using git and HTTPS into a temporary directory under {{ git_repo_basepath }}.
9270
93-
4. Update the {{ package }} to the newer version:
94-
* Create a new Git branch named `automated-package-update-{{ version }}`.
71+
3. Update the {{ package }} to the newer version:
9572
* Update the local package by:
96-
* Updating the 'Version' and 'Release' fields in the .spec file as needed (or corresponding macros),
73+
* Updating the 'Version' and 'Release' fields (or corresponding macros) in the spec file as needed,
9774
following packaging documentation.
98-
* Make sure the format of the .spec file remains the same.
99-
* Updating macros related to update (e.g., 'commit') if present and necessary; examine the file's history
75+
* Make sure the format of the spec file remains the same.
76+
* Updating macros related to update (e.g., 'commit') if present and necessary; examine the file history
10077
to see how updates are typically done.
10178
* You might need to check some information in upstream repository, e.g. the commit SHA of the new version.
10279
* Creating a changelog entry, referencing the Jira issue as "Resolves: {{ jira_issue }}".
10380
* Downloading sources using `spectool -g -S {{ package }}.spec` (you might need to copy local sources,
104-
e.g. if the .spec file loads some macros from them, to a directory where spectool expects them).
81+
e.g. if the spec file loads some macros from them, to a directory where `spectool` expects them).
10582
* Uploading the new sources using `centpkg --release {{ dist_git_branch }} new-sources`.
10683
* IMPORTANT: Only performing changes relevant to the version update: Do not rename variables,
107-
comment out existing lines, or alter if-else branches in the .spec file.
84+
comment out existing lines, or alter if-else branches in the spec file.
10885
109-
5. Verify and adjust the changes:
110-
* Use `rpmlint` to validate your .spec file changes and fix any new errors it identifies.
111-
* Generate the SRPM using `rpmbuild -bs` (ensure your .spec file and source files are correctly
86+
4. Verify and adjust the changes:
87+
* Use `rpmlint` to validate your spec file changes and fix any new errors it identifies.
88+
* Generate the SRPM using `rpmbuild -bs` (ensure your spec file and source files are correctly
11289
copied to the build environment as required by the command).
11390
114-
6. {{ rebase_git_steps }}
115-
11691
Report the status of the rebase operation including:
11792
- Whether the package was already up to date
11893
- Any errors encountered during the process
11994
- The URL of the created merge request if successful
12095
- Any validation issues found with rpmlint
12196
"""
122-
123-
# Define template function that can be called from the template
124-
def rebase_git_steps(data: dict) -> str:
125-
input_data = InputSchema.model_validate(data)
126-
return get_git_finalization_steps(
127-
package=input_data.package,
128-
jira_issue=input_data.jira_issue,
129-
commit_title=f"{COMMIT_PREFIX} Update to version {input_data.version}",
130-
files_to_commit="*.spec",
131-
branch_name=f"{BRANCH_PREFIX}-{input_data.version}",
132-
git_url=input_data.git_url,
133-
dist_git_branch=input_data.dist_git_branch,
134-
)
135-
136-
return PromptTemplate(
137-
PromptTemplateInput(schema=InputSchema, template=template, functions={"rebase_git_steps": rebase_git_steps})
138-
).render(input)
97+
return PromptTemplate(PromptTemplateInput(schema=InputSchema, template=template)).render(input)
13998

14099

141100
async def main() -> None:
@@ -144,28 +103,112 @@ async def main() -> None:
144103
setup_observability(os.getenv("COLLECTOR_ENDPOINT"))
145104

146105
async with mcp_tools(os.getenv("MCP_GATEWAY_URL")) as gateway_tools:
147-
agent = RequirementAgent(
106+
rebase_agent = RequirementAgent(
148107
llm=ChatModel.from_name(os.getenv("CHAT_MODEL")),
149-
tools=[ThinkTool(), RunShellCommandTool(), DuckDuckGoSearchTool()]
150-
+ [
151-
t
152-
for t in gateway_tools
153-
if t.name in ("fork_repository", "open_merge_request", "push_to_remote_repository")
108+
tools=[
109+
ThinkTool(),
110+
RunShellCommandTool(),
111+
DuckDuckGoSearchTool(),
112+
CreateTool(),
113+
ViewTool(),
114+
InsertTool(),
115+
StrReplaceTool(),
116+
AddChangelogEntryTool(),
154117
],
155118
memory=UnconstrainedMemory(),
156119
requirements=[
157120
ConditionalRequirement(ThinkTool, force_after=Tool, consecutive_allowed=False),
158121
],
159122
middlewares=[GlobalTrajectoryMiddleware(pretty=True)],
123+
role="Red Hat Enterprise Linux developer",
124+
instructions=[
125+
"Use the `think` tool to reason through complex decisions and document your approach.",
126+
"Preserve existing formatting and style conventions in RPM spec files and patch headers.",
127+
"Use `rpmlint *.spec` to check for packaging issues and address any NEW errors",
128+
"Ignore pre-existing rpmlint warnings unless they're related to your changes",
129+
"Run `centpkg prep` to verify all patches apply cleanly during build preparation",
130+
"Generate an SRPM using `centpkg srpm` command to ensure complete build readiness",
131+
"* IMPORTANT: Only perform changes relevant to the rebase update",
132+
],
160133
)
161134

162-
async def run(input):
163-
response = await agent.run(
164-
prompt=render_prompt(input),
165-
expected_output=OutputSchema,
166-
execution=get_agent_execution_config(),
135+
class State(BaseModel):
136+
jira_issue: str
137+
package: str
138+
dist_git_branch: str
139+
version: str
140+
local_clone: Path | None = Field(default=None)
141+
update_branch: str | None = Field(default=None)
142+
fork_url: str | None = Field(default=None)
143+
rebase_result: OutputSchema | None = Field(default=None)
144+
merge_request_url: str | None = Field(default=None)
145+
146+
workflow = Workflow(State)
147+
148+
async def fork_and_prepare_dist_git(state):
149+
state.local_clone, state.update_branch, state.fork_url = await tasks.fork_and_prepare_dist_git(
150+
jira_issue=state.jira_issue,
151+
package=state.package,
152+
dist_git_branch=state.dist_git_branch,
153+
available_tools=gateway_tools,
167154
)
168-
return OutputSchema.model_validate_json(response.answer.text)
155+
return "run_rebase_agent"
156+
157+
async def run_rebase_agent(state):
158+
cwd = Path.cwd()
159+
try:
160+
# make things easier for the LLM
161+
os.chdir(state.local_clone)
162+
response = await rebase_agent.run(
163+
prompt=render_prompt(
164+
InputSchema(
165+
local_clone=state.local_clone,
166+
package=state.package,
167+
dist_git_branch=state.dist_git_branch,
168+
version=state.version,
169+
jira_issue=state.jira_issue,
170+
),
171+
),
172+
expected_output=OutputSchema,
173+
execution=get_agent_execution_config(),
174+
)
175+
state.rebase_result = OutputSchema.model_validate_json(response.answer.text)
176+
finally:
177+
os.chdir(cwd)
178+
if state.rebase_result.success:
179+
return "commit_push_and_open_mr"
180+
else:
181+
return Workflow.END
182+
183+
async def commit_push_and_open_mr(state):
184+
state.merge_request_url = await tasks.commit_push_and_open_mr(
185+
local_clone=state.local_clone,
186+
files_to_commit="*.spec",
187+
commit_message=f"{COMMIT_PREFIX} Update to version {state.version}",
188+
fork_url=state.fork_url,
189+
dist_git_branch=state.dist_git_branch,
190+
update_branch=state.update_branch,
191+
mr_title=f"{COMMIT_PREFIX} Update to version {state.version}",
192+
mr_description="TODO",
193+
available_tools=gateway_tools,
194+
commit_only=os.getenv("DRY_RUN", "False").lower() == "true",
195+
)
196+
return Workflow.END
197+
198+
workflow.add_step("fork_and_prepare_dist_git", fork_and_prepare_dist_git)
199+
workflow.add_step("run_rebase_agent", run_rebase_agent)
200+
workflow.add_step("commit_push_and_open_mr", commit_push_and_open_mr)
201+
202+
async def run_workflow(package, dist_git_branch, version, jira_issue):
203+
response = await workflow.run(
204+
State(
205+
package=package,
206+
dist_git_branch=dist_git_branch,
207+
version=version,
208+
jira_issue=jira_issue,
209+
),
210+
)
211+
return response.state
169212

170213
if (
171214
(package := os.getenv("PACKAGE", None))
@@ -174,14 +217,13 @@ async def run(input):
174217
and (branch := os.getenv("BRANCH", None))
175218
):
176219
logger.info("Running in direct mode with environment variables")
177-
input = InputSchema(
220+
state = await run_workflow(
178221
package=package,
222+
dist_git_branch=branch,
179223
version=version,
180224
jira_issue=jira_issue,
181-
dist_git_branch=branch,
182225
)
183-
output = await run(input)
184-
logger.info(f"Direct run completed: {output.model_dump_json(indent=4)}")
226+
logger.info(f"Direct run completed: {state.rebase_result.model_dump_json(indent=4)}")
185227
return
186228

187229
class Task(BaseModel):
@@ -211,13 +253,6 @@ class Task(BaseModel):
211253
f"attempt: {task.attempts + 1}"
212254
)
213255

214-
input = InputSchema(
215-
package=rebase_data.package,
216-
version=rebase_data.version,
217-
jira_issue=rebase_data.jira_issue,
218-
dist_git_branch=rebase_data.branch,
219-
)
220-
221256
async def retry(task, error):
222257
task.attempts += 1
223258
if task.attempts < max_retries:
@@ -235,21 +270,26 @@ async def retry(task, error):
235270

236271
try:
237272
logger.info(f"Starting rebase processing for {rebase_data.jira_issue}")
238-
output = await run(input)
273+
state = await run_workflow(
274+
package=rebase_data.package,
275+
dist_git_branch=rebase_data.branch,
276+
version=rebase_data.version,
277+
jira_issue=rebase_data.jira_issue,
278+
)
239279
logger.info(
240-
f"Rebase processing completed for {rebase_data.jira_issue}, " f"success: {output.success}"
280+
f"Rebase processing completed for {rebase_data.jira_issue}, " f"success: {state.rebase_result.success}"
241281
)
242282
except Exception as e:
243283
error = "".join(traceback.format_exception(e))
244284
logger.error(f"Exception during rebase processing for {rebase_data.jira_issue}: {error}")
245-
await retry(task, ErrorData(details=error, jira_issue=input.jira_issue).model_dump_json())
285+
await retry(task, ErrorData(details=error, jira_issue=rebase_data.jira_issue).model_dump_json())
246286
else:
247-
if output.success:
287+
if state.rebase_result.success:
248288
logger.info(f"Rebase successful for {rebase_data.jira_issue}, " f"adding to completed list")
249-
await redis.lpush("completed_rebase_list", output.model_dump_json())
289+
await redis.lpush("completed_rebase_list", state.rebase_result.model_dump_json())
250290
else:
251-
logger.warning(f"Rebase failed for {rebase_data.jira_issue}: {output.error}")
252-
await retry(task, output.error)
291+
logger.warning(f"Rebase failed for {rebase_data.jira_issue}: {state.rebase_result.error}")
292+
await retry(task, state.rebase_result.error)
253293

254294

255295
if __name__ == "__main__":

0 commit comments

Comments
 (0)