1
1
import asyncio
2
2
import logging
3
3
import os
4
- from shutil import rmtree
5
- from pathlib import Path
6
4
import subprocess
7
5
import sys
8
6
import traceback
9
- from typing import Optional
7
+ from pathlib import Path
10
8
11
9
from pydantic import BaseModel , Field
12
10
22
20
from beeai_framework .tools import Tool
23
21
from beeai_framework .tools .search .duckduckgo import DuckDuckGoSearchTool
24
22
from beeai_framework .tools .think import ThinkTool
23
+ from beeai_framework .workflows import Workflow
25
24
25
+ import tasks
26
+ from constants import COMMIT_PREFIX
27
+ from observability import setup_observability
28
+ from tools .commands import RunShellCommandTool
26
29
from tools .specfile import AddChangelogEntryTool , BumpReleaseTool
27
30
from tools .text import CreateTool , InsertTool , StrReplaceTool , ViewTool
28
31
from tools .wicked_git import GitLogSearchTool , GitPatchCreationTool
29
- from constants import COMMIT_PREFIX , BRANCH_PREFIX
30
- from observability import setup_observability
31
- from tools .commands import RunShellCommandTool
32
32
from triage_agent import BackportData , ErrorData
33
- from utils import get_agent_execution_config , mcp_tools , redis_client , get_git_finalization_steps
33
+ from utils import check_subprocess , get_agent_execution_config , mcp_tools , redis_client
34
34
35
35
logger = logging .getLogger (__name__ )
36
36
37
37
38
38
class InputSchema (BaseModel ):
39
+ local_clone : Path = Field (description = "Path to the local clone of forked dist-git repository" )
40
+ unpacked_sources : Path = Field (description = "Path to the unpacked (using `centpkg prep`) sources" )
39
41
package : str = Field (description = "Package to update" )
42
+ dist_git_branch : str = Field (description = "Git branch in dist-git to be updated" )
40
43
upstream_fix : str = Field (description = "Link to an upstream fix for the issue" )
41
44
jira_issue : str = Field (description = "Jira issue to reference as resolved" )
42
45
cve_id : str = Field (default = "" , description = "CVE ID if the jira issue is a CVE" )
43
- dist_git_branch : str = Field (description = "Git branch in dist-git to be updated" )
44
- git_repo_basepath : str = Field (
45
- description = "Base path for cloned git repos" ,
46
- default = os .getenv ("GIT_REPO_BASEPATH" ),
47
- )
48
- unpacked_sources : str = Field (
49
- description = "Path to the unpacked (using `centpkg prep`) sources" ,
50
- default = "" ,
51
- )
52
46
53
47
54
48
class OutputSchema (BaseModel ):
55
49
success : bool = Field (description = "Whether the backport was successfully completed" )
56
50
status : str = Field (description = "Backport status" )
57
- mr_url : Optional [ str ] = Field (description = "URL to the opened merge request" )
58
- error : Optional [ str ] = Field (description = "Specific details about an error" )
51
+ mr_url : str | None = Field (description = "URL to the opened merge request" )
52
+ error : str | None = Field (description = "Specific details about an error" )
59
53
60
54
61
55
def render_prompt (input : InputSchema ) -> str :
62
56
template = (
63
- 'Work inside the repository cloned at "{{ git_repo_basepath }}/{{ package }}" \n '
57
+ 'Work inside the repository cloned in "{{ local_clone }}", it is your current working directory \n '
64
58
"Use the `git_log_search` tool to check if the jira issue ({{ jira_issue }}) or CVE ({{ cve_id }}) is already resolved.\n "
65
59
"If the issue or the cve are already resolved, exit the backporting process with success=True and status=\" Backport already applied\" \n "
66
60
"Download the upstream fix from {{ upstream_fix }}\n "
@@ -74,64 +68,8 @@ def render_prompt(input: InputSchema) -> str:
74
68
"Delete all *.rej files\n "
75
69
"DO **NOT** RUN COMMAND `git am --continue`\n "
76
70
"Once you resolve all conflicts, use tool git_patch_create to create a patch file\n "
77
- "{{ backport_git_steps }}"
78
71
)
79
-
80
- # Define template function that can be called from the template
81
- def backport_git_steps (data ):
82
- input_data = InputSchema .model_validate (data )
83
- return get_git_finalization_steps (
84
- package = input_data .package ,
85
- jira_issue = input_data .jira_issue ,
86
- commit_title = f"{ COMMIT_PREFIX } backport { input_data .jira_issue } " ,
87
- files_to_commit = f"*.spec and { input_data .jira_issue } .patch" ,
88
- branch_name = f"{ BRANCH_PREFIX } -{ input_data .jira_issue } " ,
89
- dist_git_branch = input_data .dist_git_branch ,
90
- )
91
-
92
- return PromptTemplate (
93
- PromptTemplateInput (schema = InputSchema , template = template , functions = {"backport_git_steps" : backport_git_steps })
94
- ).render (input )
95
-
96
-
97
- def prepare_package (
98
- package : str , jira_issue : str , dist_git_branch : str , input_schema : InputSchema
99
- ) -> tuple [Path , Path ]:
100
- """
101
- Prepare the package for backporting by cloning the dist-git repository, switching to the appropriate branch,
102
- and downloading the sources.
103
- Returns the path to the unpacked sources.
104
- """
105
- git_repo = Path (input_schema .git_repo_basepath )
106
- git_repo .mkdir (parents = True , exist_ok = True )
107
- subprocess .check_call (
108
- [
109
- "centpkg" ,
110
- "clone" ,
111
- "--anonymous" ,
112
- "--branch" ,
113
- dist_git_branch ,
114
- package ,
115
- ],
116
- cwd = git_repo ,
117
- )
118
- local_clone = git_repo / package
119
- subprocess .check_call (
120
- [
121
- "git" ,
122
- "switch" ,
123
- "-c" ,
124
- f"automated-package-update-{ jira_issue } " ,
125
- dist_git_branch ,
126
- ],
127
- cwd = local_clone ,
128
- )
129
- subprocess .check_call (["centpkg" , "sources" ], cwd = local_clone )
130
- subprocess .check_call (["centpkg" , "prep" ], cwd = local_clone )
131
- unpacked_sources = list (local_clone .glob (f"*-build/*{ package } *" ))
132
- if len (unpacked_sources ) != 1 :
133
- raise ValueError (f"Expected exactly one unpacked source, got { unpacked_sources } " )
134
- return unpacked_sources [0 ], local_clone
72
+ return PromptTemplate (PromptTemplateInput (schema = InputSchema , template = template )).render (input )
135
73
136
74
137
75
async def main () -> None :
@@ -141,7 +79,7 @@ async def main() -> None:
141
79
cve_id = os .getenv ("CVE_ID" , "" )
142
80
143
81
async with mcp_tools (os .getenv ("MCP_GATEWAY_URL" )) as gateway_tools :
144
- agent = RequirementAgent (
82
+ backport_agent = RequirementAgent (
145
83
llm = ChatModel .from_name (os .getenv ("CHAT_MODEL" )),
146
84
tools = [
147
85
ThinkTool (),
@@ -155,11 +93,6 @@ async def main() -> None:
155
93
GitLogSearchTool (),
156
94
BumpReleaseTool (),
157
95
AddChangelogEntryTool (),
158
- ]
159
- + [
160
- t
161
- for t in gateway_tools
162
- if t .name in ("fork_repository" , "open_merge_request" , "push_to_remote_repository" )
163
96
],
164
97
memory = UnconstrainedMemory (),
165
98
requirements = [
@@ -182,41 +115,110 @@ async def main() -> None:
182
115
],
183
116
)
184
117
185
- dry_run = os .getenv ("DRY_RUN" , "False" ).lower () == "true"
118
+ class State (BaseModel ):
119
+ jira_issue : str
120
+ package : str
121
+ dist_git_branch : str
122
+ upstream_fix : str
123
+ cve_id : str
124
+ local_clone : Path | None = Field (default = None )
125
+ update_branch : str | None = Field (default = None )
126
+ fork_url : str | None = Field (default = None )
127
+ unpacked_sources : Path | None = Field (default = None )
128
+ backport_result : OutputSchema | None = Field (default = None )
129
+ merge_request_url : str | None = Field (default = None )
130
+
131
+ workflow = Workflow (State )
186
132
187
- async def run (input ):
188
- response = await agent .run (
189
- prompt = render_prompt (input ),
190
- expected_output = OutputSchema ,
191
- execution = get_agent_execution_config (),
133
+ async def fork_and_prepare_dist_git (state ):
134
+ state .local_clone , state .update_branch , state .fork_url = await tasks .fork_and_prepare_dist_git (
135
+ jira_issue = state .jira_issue ,
136
+ package = state .package ,
137
+ dist_git_branch = state .dist_git_branch ,
138
+ available_tools = gateway_tools ,
192
139
)
193
- return OutputSchema .model_validate_json (response .answer .text )
140
+ await check_subprocess (["centpkg" , "sources" ], cwd = state .local_clone )
141
+ await check_subprocess (["centpkg" , "prep" ], cwd = state .local_clone )
142
+ unpacked_sources = list (state .local_clone .glob (f"*-build/*{ state .package } *" ))
143
+ if len (unpacked_sources ) != 1 :
144
+ raise ValueError (f"Expected exactly one unpacked source, got { unpacked_sources } " )
145
+ [state .unpacked_sources ] = unpacked_sources
146
+ return "run_backport_agent"
147
+
148
+ async def run_backport_agent (state ):
149
+ cwd = Path .cwd ()
150
+ try :
151
+ # make things easier for the LLM
152
+ os .chdir (state .local_clone )
153
+ response = await backport_agent .run (
154
+ prompt = render_prompt (
155
+ InputSchema (
156
+ local_clone = state .local_clone ,
157
+ unpacked_sources = state .unpacked_sources ,
158
+ package = state .package ,
159
+ dist_git_branch = state .dist_git_branch ,
160
+ upstream_fix = state .upstream_fix ,
161
+ jira_issue = state .jira_issue ,
162
+ cve_id = state .cve_id ,
163
+ ),
164
+ ),
165
+ expected_output = OutputSchema ,
166
+ execution = get_agent_execution_config (),
167
+ )
168
+ state .backport_result = OutputSchema .model_validate_json (response .answer .text )
169
+ finally :
170
+ os .chdir (cwd )
171
+ if state .backport_result .success :
172
+ return "commit_push_and_open_mr"
173
+ else :
174
+ return Workflow .END
175
+
176
+ async def commit_push_and_open_mr (state ):
177
+ state .merge_request_url = await tasks .commit_push_and_open_mr (
178
+ local_clone = state .local_clone ,
179
+ files_to_commit = ["*.spec" , f"{ state .jira_issue } .patch" ],
180
+ commit_message = f"{ COMMIT_PREFIX } backport { state .jira_issue } " ,
181
+ fork_url = state .fork_url ,
182
+ dist_git_branch = state .dist_git_branch ,
183
+ update_branch = state .update_branch ,
184
+ mr_title = "{COMMIT_PREFIX} backport {state.jira_issue}" ,
185
+ mr_description = "TODO" ,
186
+ available_tools = gateway_tools ,
187
+ commit_only = os .getenv ("DRY_RUN" , "False" ).lower () == "true" ,
188
+ )
189
+ return Workflow .END
190
+
191
+ workflow .add_step ("fork_and_prepare_dist_git" , fork_and_prepare_dist_git )
192
+ workflow .add_step ("run_backport_agent" , run_backport_agent )
193
+ workflow .add_step ("commit_push_and_open_mr" , commit_push_and_open_mr )
194
+
195
+ async def run_workflow (package , dist_git_branch , upstream_fix , jira_issue , cve_id ):
196
+ response = await workflow .run (
197
+ State (
198
+ package = package ,
199
+ dist_git_branch = dist_git_branch ,
200
+ upstream_fix = upstream_fix ,
201
+ jira_issue = jira_issue ,
202
+ cve_id = cve_id ,
203
+ ),
204
+ )
205
+ return response .state
194
206
195
207
if (
196
208
(package := os .getenv ("PACKAGE" , None ))
209
+ and (branch := os .getenv ("BRANCH" , None ))
197
210
and (upstream_fix := os .getenv ("UPSTREAM_FIX" , None ))
198
211
and (jira_issue := os .getenv ("JIRA_ISSUE" , None ))
199
- and (branch := os .getenv ("BRANCH" , None ))
200
212
):
201
213
logger .info ("Running in direct mode with environment variables" )
202
- input = InputSchema (
214
+ state = await run_workflow (
203
215
package = package ,
216
+ dist_git_branch = branch ,
204
217
upstream_fix = upstream_fix ,
205
218
jira_issue = jira_issue ,
206
- dist_git_branch = branch ,
207
- cve_id = cve_id ,
219
+ cve_id = os .getenv ("CVE_ID" , "" ),
208
220
)
209
- unpacked_sources , local_clone = prepare_package (package , jira_issue , branch , input )
210
- input .unpacked_sources = str (unpacked_sources )
211
- try :
212
- output = await run (input )
213
- finally :
214
- if not dry_run :
215
- logger .info (f"Removing { local_clone } " )
216
- rmtree (local_clone )
217
- else :
218
- logger .info (f"DRY RUN: Not removing { local_clone } " )
219
- logger .info (f"Direct run completed: { output .model_dump_json (indent = 4 )} " )
221
+ logger .info (f"Direct run completed: { state .backport_result .model_dump_json (indent = 4 )} " )
220
222
return
221
223
222
224
class Task (BaseModel ):
@@ -245,18 +247,6 @@ class Task(BaseModel):
245
247
f"JIRA: { backport_data .jira_issue } , attempt: { task .attempts + 1 } "
246
248
)
247
249
248
- input = InputSchema (
249
- package = backport_data .package ,
250
- upstream_fix = backport_data .patch_url ,
251
- jira_issue = backport_data .jira_issue ,
252
- dist_git_branch = backport_data .branch ,
253
- cve_id = backport_data .cve_id ,
254
- )
255
- unpacked_sources , local_clone = prepare_package (
256
- backport_data .package , backport_data .jira_issue , backport_data .branch , input
257
- )
258
- input .unpacked_sources = str (unpacked_sources )
259
-
260
250
async def retry (task , error ):
261
251
task .attempts += 1
262
252
if task .attempts < max_retries :
@@ -274,23 +264,29 @@ async def retry(task, error):
274
264
275
265
try :
276
266
logger .info (f"Starting backport processing for { backport_data .jira_issue } " )
277
- output = await run (input )
267
+ state = await run_workflow (
268
+ package = backport_data .package ,
269
+ dist_git_branch = backport_data .branch ,
270
+ upstream_fix = backport_data .patch_url ,
271
+ jira_issue = backport_data .jira_issue ,
272
+ cve_id = backport_data .cve_id ,
273
+ )
278
274
logger .info (
279
- f"Backport processing completed for { backport_data .jira_issue } , " f"success: { output .success } "
275
+ f"Backport processing completed for { backport_data .jira_issue } , " f"success: { state . backport_result .success } "
280
276
)
281
277
except Exception as e :
282
278
error = "" .join (traceback .format_exception (e ))
283
279
logger .error (f"Exception during backport processing for { backport_data .jira_issue } : { error } " )
284
- await retry (task , ErrorData (details = error , jira_issue = input .jira_issue ).model_dump_json ())
280
+ await retry (task , ErrorData (details = error , jira_issue = backport_data .jira_issue ).model_dump_json ())
285
281
rmtree (local_clone )
286
282
else :
287
283
rmtree (local_clone )
288
- if output .success :
284
+ if state . backport_data .success :
289
285
logger .info (f"Backport successful for { backport_data .jira_issue } , " f"adding to completed list" )
290
- await redis .lpush ("completed_backport_list" , output .model_dump_json ())
286
+ await redis .lpush ("completed_backport_list" , state . backport_data .model_dump_json ())
291
287
else :
292
- logger .warning (f"Backport failed for { backport_data .jira_issue } : { output .error } " )
293
- await retry (task , output .error )
288
+ logger .warning (f"Backport failed for { backport_data .jira_issue } : { state . backport_data .error } " )
289
+ await retry (task , state . backport_data .error )
294
290
295
291
296
292
if __name__ == "__main__" :
0 commit comments