Skip to content

Commit ee6f860

Browse files
Merge pull request #5 from kevinbackhouse/available-tools
Refactor available tools into a class so that it can be passed around…
2 parents 1270c11 + 0b7440f commit ee6f860

File tree

2 files changed

+28
-18
lines changed

2 files changed

+28
-18
lines changed

available_tools.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class AvailableTools:
2+
"""
3+
This class is used for storing dictionaries of all the available
4+
personalities, taskflows, and prompts.
5+
"""
6+
def __init__(self, personalities: dict, taskflows: dict, prompts: dict):
7+
self.personalities = personalities
8+
self.taskflows = taskflows
9+
self.prompts = prompts

main.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from yaml_parser import YamlParser
3030
from agent import TaskAgent
3131
from capi import list_tool_call_models
32+
from available_tools import AvailableTools
3233

3334
load_dotenv()
3435

@@ -48,9 +49,8 @@
4849
MAX_API_RETRY = 5
4950
MCP_CLEANUP_TIMEOUT = 5
5051

51-
def parse_prompt_args(user_prompt: str | None = None):
52-
available_personalities = YamlParser('personalities').get_yaml_dict()
53-
available_taskflows = YamlParser('taskflows').get_yaml_dict()
52+
def parse_prompt_args(available_tools: AvailableTools,
53+
user_prompt: str | None = None):
5454
parser = argparse.ArgumentParser(add_help=False, description="SecLab Taskflow Agent")
5555
parser.prog = ''
5656
group = parser.add_mutually_exclusive_group()
@@ -61,10 +61,10 @@ def parse_prompt_args(user_prompt: str | None = None):
6161
#parser.add_argument('remainder', nargs=argparse.REMAINDER, help="Remaining args")
6262
help_msg = parser.format_help()
6363
help_msg += "\nAvailable Personalities:\n\n"
64-
for k in available_personalities:
64+
for k in available_tools.personalities:
6565
help_msg += f"`{k}`\n"
6666
help_msg += "\nAvailable Taskflows:\n\n"
67-
for k in available_taskflows:
67+
for k in available_tools.taskflows:
6868
help_msg += f"`{k}`\n"
6969
help_msg += "\nExamples:\n\n"
7070
help_msg += "`-p assistant explain modems to me please`\n"
@@ -372,11 +372,8 @@ async def _run_streamed():
372372
logging.error(f"Exception in mcp server cleanup task: {e}")
373373

374374

375-
async def main(p: str | None, t: str | None, prompt: str | None):
376-
377-
available_personalities = YamlParser('personalities').get_yaml_dict()
378-
available_taskflows = YamlParser('taskflows').get_yaml_dict()
379-
available_prompts = YamlParser('prompts').get_yaml_dict(dir_namespace=True)
375+
async def main(available_tools: AvailableTools,
376+
p: str | None, t: str | None, prompt: str | None):
380377
last_mcp_tool_results = [] # XXX: memleaky
381378

382379
async def on_tool_end_hook(
@@ -399,7 +396,7 @@ async def on_handoff_hook(
399396
await render_model_output(f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n")
400397

401398
if p:
402-
personality = available_personalities.get(p)
399+
personality = available_tools.personalities.get(p)
403400
if personality is None:
404401
raise ValueError("No such personality!")
405402

@@ -412,7 +409,7 @@ async def on_handoff_hook(
412409

413410
if t:
414411

415-
taskflow = available_taskflows.get(t)
412+
taskflow = available_tools.taskflows.get(t)
416413
if taskflow is None:
417414
raise ValueError("No such taskflow!")
418415

@@ -431,7 +428,7 @@ async def on_handoff_hook(
431428
# can tweak reusable task configurations as they see fit
432429
uses = task_body.get('uses', '')
433430
if uses:
434-
reusable_taskflow = available_taskflows.get(uses)
431+
reusable_taskflow = available_tools.taskflows.get(uses)
435432
if reusable_taskflow is None:
436433
raise ValueError(f"No such reusable taskflow: {uses}")
437434
if len(reusable_taskflow['taskflow']) > 1:
@@ -479,7 +476,7 @@ def preprocess_prompt(prompt: str, tag: str, kv: dict, kv_subkey=None):
479476

480477
# pre-process the prompt for any prompts
481478
if prompt:
482-
prompt = preprocess_prompt(prompt, 'PROMPTS', available_prompts, 'prompt')
479+
prompt = preprocess_prompt(prompt, 'PROMPTS', available_tools.prompts, 'prompt')
483480

484481
# pre-process the prompt for any inputs
485482
if prompt and inputs:
@@ -566,10 +563,10 @@ async def run_prompts(async_task=False, max_concurrent_tasks=5):
566563
if not agents:
567564
# XXX: deprecate the -p parser for taskflows entirely?
568565
# XXX: probably just adds unneeded parsing complexity
569-
p, _, _, prompt, _ = parse_prompt_args(prompt)
566+
p, _, _, prompt, _ = parse_prompt_args(available_tools, prompt)
570567
agents.append(p)
571568
for p in agents:
572-
personality = available_personalities.get(p)
569+
personality = available_tools.personalities.get(p)
573570
if personality is None:
574571
raise ValueError(f"No such personality: {p}")
575572
resolved_agents[p] = personality
@@ -628,8 +625,12 @@ async def _deploy_task_agents(resolved_agents, prompt):
628625
break
629626

630627
if __name__ == '__main__':
628+
available_tools = AvailableTools(
629+
personalities = YamlParser('personalities').get_yaml_dict(),
630+
taskflows = YamlParser('taskflows').get_yaml_dict(),
631+
prompts = YamlParser('prompts').get_yaml_dict(dir_namespace=True))
631632

632-
p, t, l, user_prompt, help_msg = parse_prompt_args()
633+
p, t, l, user_prompt, help_msg = parse_prompt_args(available_tools)
633634

634635
if l:
635636
tool_models = list_tool_call_models(os.getenv('COPILOT_TOKEN'))
@@ -641,4 +642,4 @@ async def _deploy_task_agents(resolved_agents, prompt):
641642
print(help_msg)
642643
sys.exit(1)
643644

644-
asyncio.run(main(p, t, user_prompt), debug=True)
645+
asyncio.run(main(available_tools, p, t, user_prompt), debug=True)

0 commit comments

Comments
 (0)