Skip to content

Commit f59de8b

Browse files
hatch fmt --check --fix
1 parent 27fe932 commit f59de8b

File tree

21 files changed

+173
-141
lines changed

21 files changed

+173
-141
lines changed

release_tools/copy_files.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33

44
import os
55
import shutil
6-
import sys
76
import subprocess
7+
import sys
8+
89

910
def read_file_list(list_path):
1011
"""
1112
Reads a file containing file paths, ignoring empty lines and lines starting with '#'.
1213
Returns a list of relative file paths.
1314
"""
14-
with open(list_path, "r") as f:
15+
with open(list_path) as f:
1516
lines = [line.strip() for line in f]
1617
return [line for line in lines if line and not line.startswith("#")]
1718

release_tools/publish_docker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import sys
88
import tempfile
99

10+
1011
def read_file_list(list_path):
1112
"""
1213
Reads a file containing file paths, ignoring empty lines and lines starting with '#'.
1314
Returns a list of relative file paths.
1415
"""
15-
with open(list_path, "r") as f:
16+
with open(list_path) as f:
1617
lines = [line.strip() for line in f]
1718
return [line for line in lines if line and not line.startswith("#")]
1819

src/seclab_taskflow_agent/__main__.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,45 @@
11
# SPDX-FileCopyrightText: 2025 GitHub
22
# SPDX-License-Identifier: MIT
33

4-
import asyncio
5-
from threading import Thread
64
import argparse
7-
import os
8-
import sys
9-
from dotenv import load_dotenv
5+
import asyncio
6+
import json
107
import logging
11-
from logging.handlers import RotatingFileHandler
12-
from pprint import pprint, pformat
8+
import os
9+
import pathlib
1310
import re
14-
import json
11+
import sys
1512
import uuid
16-
import pathlib
13+
from logging.handlers import RotatingFileHandler
14+
from pprint import pformat
15+
from typing import Callable
1716

18-
from .agent import DEFAULT_MODEL, TaskRunHooks, TaskAgentHooks
19-
#from agents.run import DEFAULT_MAX_TURNS # XXX: this is 10, we need more than that
20-
from agents.exceptions import MaxTurnsExceeded, AgentsException
17+
from agents import Agent, RunContextWrapper, TContext, Tool
2118
from agents.agent import ModelSettings
22-
from agents.mcp import MCPServer, MCPServerStdio, MCPServerSse, MCPServerStreamableHttp, create_static_tool_filter
19+
20+
#from agents.run import DEFAULT_MAX_TURNS # XXX: this is 10, we need more than that
21+
from agents.exceptions import AgentsException, MaxTurnsExceeded
2322
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
24-
from agents import Tool, RunContextWrapper, TContext, Agent
25-
from openai import BadRequestError, APITimeoutError, RateLimitError
23+
from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter
24+
from dotenv import load_dotenv
25+
from openai import APITimeoutError, BadRequestError, RateLimitError
2626
from openai.types.responses import ResponseTextDeltaEvent
27-
from typing import Callable
2827

29-
from .shell_utils import shell_tool_call
30-
from .mcp_utils import DEFAULT_MCP_CLIENT_SESSION_TIMEOUT, ReconnectingMCPServerStdio, AsyncDebugMCPServerStdio, MCPNamespaceWrap, mcp_client_params, mcp_system_prompt, StreamableMCPThread, compress_name
31-
from .render_utils import render_model_output, flush_async_output
32-
from .env_utils import TmpEnv
33-
from .agent import TaskAgent
34-
from .capi import list_tool_call_models
28+
from .agent import DEFAULT_MODEL, TaskAgent, TaskAgentHooks, TaskRunHooks
3529
from .available_tools import AvailableTools
30+
from .capi import list_tool_call_models
31+
from .env_utils import TmpEnv
32+
from .mcp_utils import (
33+
DEFAULT_MCP_CLIENT_SESSION_TIMEOUT,
34+
MCPNamespaceWrap,
35+
ReconnectingMCPServerStdio,
36+
StreamableMCPThread,
37+
compress_name,
38+
mcp_client_params,
39+
mcp_system_prompt,
40+
)
41+
from .render_utils import flush_async_output, render_model_output
42+
from .shell_utils import shell_tool_call
3643

3744
load_dotenv()
3845

@@ -74,7 +81,7 @@ def parse_prompt_args(available_tools: AvailableTools,
7481
args = parser.parse_known_args(user_prompt.split(' ') if user_prompt else None)
7582
except SystemExit as e:
7683
if e.code == 2:
77-
logging.error(f"User provided incomplete prompt: {user_prompt}")
84+
logging.exception(f"User provided incomplete prompt: {user_prompt}")
7885
return None, None, None, help_msg
7986
p = args[0].p.strip() if args[0].p else None
8087
t = args[0].t.strip() if args[0].t else None
@@ -218,14 +225,13 @@ async def mcp_session_task(
218225
except Exception as e:
219226
print(f"Streamable mcp server process exception: {e}")
220227
except asyncio.CancelledError:
221-
logging.error(f"Timeout on cleanup for mcp server: {server._name}")
228+
logging.exception(f"Timeout on cleanup for mcp server: {server._name}")
222229
finally:
223230
mcp_servers.remove(s)
224231
except RuntimeError as e:
225-
logging.error(f"RuntimeError in mcp session task: {e}")
232+
logging.exception(f"RuntimeError in mcp session task: {e}")
226233
except asyncio.CancelledError as e:
227-
logging.error(f"Timeout on main session task: {e}")
228-
pass
234+
logging.exception(f"Timeout on main session task: {e}")
229235
finally:
230236
mcp_servers.clear()
231237

@@ -318,17 +324,17 @@ async def _run_streamed():
318324
return
319325
except APITimeoutError:
320326
if not max_retry:
321-
logging.error(f"Max retries for APITimeoutError reached")
327+
logging.exception("Max retries for APITimeoutError reached")
322328
raise
323329
max_retry -= 1
324330
except RateLimitError:
325331
if rate_limit_backoff == MAX_RATE_LIMIT_BACKOFF:
326-
raise APITimeoutError(f"Max rate limit backoff reached")
332+
raise APITimeoutError("Max rate limit backoff reached")
327333
if rate_limit_backoff > MAX_RATE_LIMIT_BACKOFF:
328334
rate_limit_backoff = MAX_RATE_LIMIT_BACKOFF
329335
else:
330336
rate_limit_backoff += rate_limit_backoff
331-
logging.error(f"Hit rate limit ... holding for {rate_limit_backoff}")
337+
logging.exception(f"Hit rate limit ... holding for {rate_limit_backoff}")
332338
await asyncio.sleep(rate_limit_backoff)
333339
await _run_streamed()
334340
complete = True
@@ -338,22 +344,22 @@ async def _run_streamed():
338344
await render_model_output(f"** 🤖❗ Max Turns Reached: {e}\n",
339345
async_task=async_task,
340346
task_id=task_id)
341-
logging.error(f"Exceeded max_turns: {max_turns}")
347+
logging.exception(f"Exceeded max_turns: {max_turns}")
342348
except AgentsException as e:
343349
await render_model_output(f"** 🤖❗ Agent Exception: {e}\n",
344350
async_task=async_task,
345351
task_id=task_id)
346-
logging.error(f"Agent Exception: {e}")
352+
logging.exception(f"Agent Exception: {e}")
347353
except BadRequestError as e:
348354
await render_model_output(f"** 🤖❗ Request Error: {e}\n",
349355
async_task=async_task,
350356
task_id=task_id)
351-
logging.error(f"Bad Request: {e}")
357+
logging.exception(f"Bad Request: {e}")
352358
except APITimeoutError as e:
353359
await render_model_output(f"** 🤖❗ Timeout Error: {e}\n",
354360
async_task=async_task,
355361
task_id=task_id)
356-
logging.error(f"Bad Request: {e}")
362+
logging.exception(f"Bad Request: {e}")
357363

358364
if async_task:
359365
await flush_async_output(task_id)
@@ -369,10 +375,10 @@ async def _run_streamed():
369375
try:
370376
cleanup_attempts_left -= 1
371377
await asyncio.wait_for(mcp_sessions, timeout=MCP_CLEANUP_TIMEOUT)
372-
except asyncio.TimeoutError as e:
378+
except asyncio.TimeoutError:
373379
continue
374380
except Exception as e:
375-
logging.error(f"Exception in mcp server cleanup task: {e}")
381+
logging.exception(f"Exception in mcp server cleanup task: {e}")
376382

377383

378384
async def main(available_tools: AvailableTools,
@@ -425,7 +431,7 @@ async def on_handoff_hook(
425431
if model_dict:
426432
if not isinstance(model_dict, dict):
427433
raise ValueError(f"Models section of the model_config file {model_config} must be a dictionary")
428-
model_keys = model_dict.keys()
434+
model_keys = model_dict.keys()
429435

430436
for task in taskflow['taskflow']:
431437

@@ -557,15 +563,15 @@ async def run_prompts(async_task=False, max_concurrent_tasks=5):
557563

558564
# if this is a shell task, execute that and append the results
559565
if run:
560-
await render_model_output(f"** 🤖🐚 Executing Shell Task\n")
566+
await render_model_output("** 🤖🐚 Executing Shell Task\n")
561567
# this allows e.g. shell based jq output to become available for repeat prompts
562568
try:
563569
result = shell_tool_call(run).content[0].model_dump_json()
564570
last_mcp_tool_results.append(result)
565571
return True
566572
except RuntimeError as e:
567573
await render_model_output(f"** 🤖❗ Shell Task Exception: {e}\n")
568-
logging.error(f"Shell task error: {e}")
574+
logging.exception(f"Shell task error: {e}")
569575
return False
570576

571577
tasks = []

src/seclab_taskflow_agent/agent.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,32 @@
22
# SPDX-License-Identifier: MIT
33

44
# https://openai.github.io/openai-agents-python/agents/
5-
import os
65
import logging
7-
from dotenv import load_dotenv
6+
import os
87
from collections.abc import Callable
98
from typing import Any
109
from urllib.parse import urlparse
1110

11+
from agents import (
12+
Agent,
13+
AgentHooks,
14+
OpenAIChatCompletionsModel,
15+
RunContextWrapper,
16+
RunHooks,
17+
Runner,
18+
TContext,
19+
Tool,
20+
result,
21+
set_default_openai_api,
22+
set_default_openai_client,
23+
set_tracing_disabled,
24+
)
25+
from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult
26+
from agents.run import DEFAULT_MAX_TURNS, RunHooks
27+
from dotenv import load_dotenv
1228
from openai import AsyncOpenAI
13-
from agents.agent import ModelSettings, ToolsToFinalOutputResult, FunctionToolResult
14-
from agents.run import DEFAULT_MAX_TURNS
15-
from agents.run import RunHooks
16-
from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled
1729

18-
from .capi import COPILOT_INTEGRATION_ID, COPILOT_API_ENDPOINT
30+
from .capi import COPILOT_API_ENDPOINT, COPILOT_INTEGRATION_ID
1931

2032
# grab our secrets from .env, this must be in .gitignore
2133
load_dotenv()

src/seclab_taskflow_agent/available_tools.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# SPDX-FileCopyrightText: 2025 GitHub
22
# SPDX-License-Identifier: MIT
33

4-
from enum import Enum
5-
import logging
64
import importlib.resources
5+
from enum import Enum
6+
77
import yaml
88

9+
910
class BadToolNameError(Exception):
1011
pass
1112

@@ -74,7 +75,7 @@ def get_tool(self, tooltype: AvailableToolType, toolname: str):
7475
version = header['version']
7576
if version != 1:
7677
raise VersionException(str(version))
77-
filetype = header['filetype']
78+
filetype = header['filetype']
7879
if filetype != tooltype.value:
7980
raise FileTypeException(
8081
f'Error in {f}: expected filetype to be {tooltype}, but it\'s {filetype}.')

src/seclab_taskflow_agent/capi.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
# SPDX-License-Identifier: MIT
33

44
# CAPI specific interactions
5-
import httpx
65
import json
76
import logging
87
import os
98
from urllib.parse import urlparse
109

10+
import httpx
11+
1112
# you can also set https://models.github.ai/inference if you prefer
1213
# but beware that your taskflows need to reference the correct model id
1314
# since the Modeld API uses it's own id schema, use -l with your desired
@@ -43,11 +44,11 @@ def list_capi_models(token: str) -> dict[str, dict]:
4344
for model in models_list:
4445
models[model.get('id')] = dict(model)
4546
except httpx.RequestError as e:
46-
logging.error(f"Request error: {e}")
47+
logging.exception(f"Request error: {e}")
4748
except json.JSONDecodeError as e:
48-
logging.error(f"JSON error: {e}")
49+
logging.exception(f"JSON error: {e}")
4950
except httpx.HTTPStatusError as e:
50-
logging.error(f"HTTP error: {e}")
51+
logging.exception(f"HTTP error: {e}")
5152
return models
5253

5354
def supports_tool_calls(model: str, models: dict) -> bool:

src/seclab_taskflow_agent/env_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# SPDX-FileCopyrightText: 2025 GitHub
22
# SPDX-License-Identifier: MIT
33

4-
import re
54
import os
5+
import re
6+
67

78
def swap_env(s):
89
match = re.search(r"{{\s*(env)\s+([A-Z0-9_]+)\s*}}", s)

src/seclab_taskflow_agent/mcp_servers/codeql/client.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
# SPDX-License-Identifier: MIT
33

44
# a query-server2 codeql client
5-
import subprocess
6-
import re
75
import json
8-
from pathlib import Path
6+
import os
7+
import re
8+
import subprocess
99
import tempfile
1010
import time
11-
from urllib.parse import urlparse, unquote
12-
import os
1311
import zipfile
12+
from pathlib import Path
13+
from urllib.parse import unquote, urlparse
14+
1415
import yaml
1516

1617
# this is a local fork of https://github.com/riga/jsonrpyc modified for our purposes
@@ -46,7 +47,7 @@ def __init__(self,
4647
self.server_options = server_options.copy()
4748
if log_stderr:
4849
os.makedirs("logs", exist_ok=True)
49-
self.stderr_log = f"logs/codeql_stderr_log.log"
50+
self.stderr_log = "logs/codeql_stderr_log.log"
5051
self.server_options.append("--log-to-stderr")
5152
else:
5253
self.stderr_log = os.devnull
@@ -271,7 +272,7 @@ def _search_path(self):
271272

272273
def _search_paths_from_codeql_config(self, config="~/.config/codeql/config"):
273274
try:
274-
with open(config, 'r') as f:
275+
with open(config) as f:
275276
match = re.search(r"^--search-path(\s+|=)\s*(.*)", f.read())
276277
if match and match.group(2):
277278
return match.group(2).split(':')
@@ -530,7 +531,7 @@ def _file_from_src_archive(relative_path: str | Path, database_path: str | Path,
530531
# fall back to relative path if resolved_path does not exist (might be a build dep file)
531532
if str(resolved_path) not in files:
532533
resolved_path = Path(relative_path)
533-
file_data = shell_command_to_string(["unzip", "-p", src_path, f"{str(resolved_path)}"])
534+
file_data = shell_command_to_string(["unzip", "-p", src_path, f"{resolved_path!s}"])
534535
if region:
535536
def region_from_file():
536537
# regions are 1+ based and look like 1:2:3:4

0 commit comments

Comments
 (0)