Skip to content

Commit 9ac4db3

Browse files
Improvements to the code written by Copilot.
1 parent e29ab42 commit 9ac4db3

File tree

5 files changed

+158
-112
lines changed

5 files changed

+158
-112
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ dependencies = [
9797
"SQLAlchemy==2.0.41",
9898
"sse-starlette==2.4.1",
9999
"starlette==0.49.1",
100+
"strenum==0.4.15",
100101
"tqdm==4.67.1",
101102
"typer==0.16.0",
102103
"types-requests==2.32.4.20250611",

src/seclab_taskflow_agent/capi.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,31 @@
66
import json
77
import logging
88
import os
9+
from strenum import StrEnum
910
from urllib.parse import urlparse
1011

1112
# you can also set https://api.githubcopilot.com if you prefer
1213
# but beware that your taskflows need to reference the correct model id
1314
# since different APIs use their own id schema, use -l with your desired
1415
# endpoint to retrieve the correct id names to use for your taskflow
1516
AI_API_ENDPOINT = os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')
17+
18+
class AI_API_ENDPOINT_ENUM(StrEnum):
19+
AI_API_MODELS_GITHUB = 'models.github.ai'
20+
AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'
21+
1622
COPILOT_INTEGRATION_ID = 'vscode-chat'
1723

1824
# assume we are >= python 3.9 for our type hints
1925
def list_capi_models(token: str) -> dict[str, dict]:
2026
"""Retrieve a dictionary of available CAPI models"""
2127
models = {}
2228
try:
23-
match urlparse(AI_API_ENDPOINT).netloc:
24-
case 'api.githubcopilot.com':
29+
netloc = urlparse(AI_API_ENDPOINT).netloc
30+
match netloc:
31+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
2532
models_catalog = 'models'
26-
case 'models.github.ai':
33+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
2734
models_catalog = 'catalog/models'
2835
case _:
2936
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
@@ -35,11 +42,13 @@ def list_capi_models(token: str) -> dict[str, dict]:
3542
})
3643
r.raise_for_status()
3744
# CAPI vs Models API
38-
match urlparse(AI_API_ENDPOINT).netloc:
39-
case 'api.githubcopilot.com':
45+
match netloc:
46+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
4047
models_list = r.json().get('data', [])
41-
case 'models.github.ai':
48+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
4249
models_list = r.json()
50+
case _:
51+
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
4352
for model in models_list:
4453
models[model.get('id')] = dict(model)
4554
except httpx.RequestError as e:
@@ -52,12 +61,12 @@ def list_capi_models(token: str) -> dict[str, dict]:
5261

5362
def supports_tool_calls(model: str, models: dict) -> bool:
5463
match urlparse(AI_API_ENDPOINT).netloc:
55-
case 'api.githubcopilot.com':
64+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
5665
return models.get(model, {}).\
5766
get('capabilities', {}).\
5867
get('supports', {}).\
5968
get('tool_calls', False)
60-
case 'models.github.ai':
69+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
6170
return 'tool-calling' in models.get(model, {}).\
6271
get('capabilities', [])
6372
case _:

tests/test_api_endpoint_config.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# SPDX-FileCopyrightText: 2025 GitHub
2+
# SPDX-License-Identifier: MIT
3+
4+
"""
5+
Test API endpoint configuration.
6+
"""
7+
8+
import pytest
9+
import tempfile
10+
from pathlib import Path
11+
import yaml
12+
import os
13+
from urllib.parse import urlparse
14+
from seclab_taskflow_agent.available_tools import AvailableTools
15+
16+
class TestAPIEndpoint:
17+
"""Test API endpoint configuration."""
18+
19+
@staticmethod
20+
def _reload_capi_module():
21+
"""Helper method to reload the capi module."""
22+
import importlib
23+
import seclab_taskflow_agent.capi
24+
importlib.reload(seclab_taskflow_agent.capi)
25+
26+
def test_default_api_endpoint(self):
27+
"""Test that default API endpoint is set to models.github.ai/inference."""
28+
from seclab_taskflow_agent.capi import AI_API_ENDPOINT, AI_API_ENDPOINT_ENUM
29+
# When no env var is set, it should default to models.github.ai/inference
30+
# Note: We can't easily test this without manipulating the environment
31+
# so we'll just import and verify the constant exists
32+
assert AI_API_ENDPOINT is not None
33+
assert isinstance(AI_API_ENDPOINT, str)
34+
assert urlparse(AI_API_ENDPOINT).netloc == AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
35+
36+
def test_api_endpoint_env_override(self):
37+
"""Test that AI_API_ENDPOINT can be overridden by environment variable."""
38+
# Save original env
39+
original_env = os.environ.get('AI_API_ENDPOINT')
40+
41+
try:
42+
# Set custom endpoint
43+
test_endpoint = 'https://test.example.com'
44+
os.environ['AI_API_ENDPOINT'] = test_endpoint
45+
46+
# Reload the module to pick up the new env var
47+
self._reload_capi_module()
48+
49+
from seclab_taskflow_agent.capi import AI_API_ENDPOINT
50+
assert AI_API_ENDPOINT == test_endpoint
51+
finally:
52+
# Restore original env
53+
if original_env is None:
54+
os.environ.pop('AI_API_ENDPOINT', None)
55+
else:
56+
os.environ['AI_API_ENDPOINT'] = original_env
57+
# Reload again to restore original state
58+
self._reload_capi_module()
59+
60+
if __name__ == '__main__':
61+
pytest.main([__file__, '-v'])

tests/test_cli_parser.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# SPDX-FileCopyrightText: 2025 GitHub
2+
# SPDX-License-Identifier: MIT
3+
4+
"""
5+
Test CLI global variable parsing.
6+
"""
7+
8+
import pytest
9+
import tempfile
10+
from pathlib import Path
11+
import yaml
12+
import os
13+
from urllib.parse import urlparse
14+
from seclab_taskflow_agent.available_tools import AvailableTools
15+
16+
class TestCliGlobals:
17+
"""Test CLI global variable parsing."""
18+
19+
def test_parse_single_global(self):
20+
"""Test parsing a single global variable from command line."""
21+
from seclab_taskflow_agent.__main__ import parse_prompt_args
22+
available_tools = AvailableTools()
23+
24+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
25+
available_tools, "-t example -g fruit=apples")
26+
27+
assert t == "example"
28+
assert cli_globals == {"fruit": "apples"}
29+
assert p is None
30+
assert l is False
31+
32+
def test_parse_multiple_globals(self):
33+
"""Test parsing multiple global variables from command line."""
34+
from seclab_taskflow_agent.__main__ import parse_prompt_args
35+
available_tools = AvailableTools()
36+
37+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
38+
available_tools, "-t example -g fruit=apples -g color=red")
39+
40+
assert t == "example"
41+
assert cli_globals == {"fruit": "apples", "color": "red"}
42+
assert p is None
43+
assert l is False
44+
45+
def test_parse_global_with_spaces(self):
46+
"""Test parsing global variables with spaces in values."""
47+
from seclab_taskflow_agent.__main__ import parse_prompt_args
48+
available_tools = AvailableTools()
49+
50+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
51+
available_tools, "-t example -g message=hello world")
52+
53+
assert t == "example"
54+
# "world" becomes part of the prompt, not the value
55+
assert cli_globals == {"message": "hello"}
56+
assert "world" in user_prompt
57+
58+
def test_parse_global_with_equals_in_value(self):
59+
"""Test parsing global variables with equals sign in value."""
60+
from seclab_taskflow_agent.__main__ import parse_prompt_args
61+
available_tools = AvailableTools()
62+
63+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
64+
available_tools, "-t example -g equation=x=5")
65+
66+
assert t == "example"
67+
assert cli_globals == {"equation": "x=5"}
68+
69+
def test_globals_in_taskflow_file(self):
70+
"""Test that globals can be read from taskflow file."""
71+
available_tools = AvailableTools()
72+
73+
taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow")
74+
assert 'globals' in taskflow
75+
assert taskflow['globals']['test_var'] == 'default_value'
76+
77+
if __name__ == '__main__':
78+
pytest.main([__file__, '-v'])

tests/test_yaml_parser.py

Lines changed: 1 addition & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313
import yaml
1414
import os
15+
from urllib.parse import urlparse
1516
from seclab_taskflow_agent.available_tools import AvailableTools
1617

1718
class TestYamlParser:
@@ -44,109 +45,5 @@ def test_parse_example_taskflows(self):
4445
assert len(example_task_flow['taskflow']) == 4 # 4 tasks in taskflow
4546
assert example_task_flow['taskflow'][0]['task']['max_steps'] == 20
4647

47-
class TestCliGlobals:
48-
"""Test CLI global variable parsing."""
49-
50-
def test_parse_single_global(self):
51-
"""Test parsing a single global variable from command line."""
52-
from seclab_taskflow_agent.__main__ import parse_prompt_args
53-
available_tools = AvailableTools()
54-
55-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
56-
available_tools, "-t example -g fruit=apples")
57-
58-
assert t == "example"
59-
assert cli_globals == {"fruit": "apples"}
60-
assert p is None
61-
assert l is False
62-
63-
def test_parse_multiple_globals(self):
64-
"""Test parsing multiple global variables from command line."""
65-
from seclab_taskflow_agent.__main__ import parse_prompt_args
66-
available_tools = AvailableTools()
67-
68-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
69-
available_tools, "-t example -g fruit=apples -g color=red")
70-
71-
assert t == "example"
72-
assert cli_globals == {"fruit": "apples", "color": "red"}
73-
assert p is None
74-
assert l is False
75-
76-
def test_parse_global_with_spaces(self):
77-
"""Test parsing global variables with spaces in values."""
78-
from seclab_taskflow_agent.__main__ import parse_prompt_args
79-
available_tools = AvailableTools()
80-
81-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
82-
available_tools, "-t example -g message=hello world")
83-
84-
assert t == "example"
85-
# "world" becomes part of the prompt, not the value
86-
assert cli_globals == {"message": "hello"}
87-
assert "world" in user_prompt
88-
89-
def test_parse_global_with_equals_in_value(self):
90-
"""Test parsing global variables with equals sign in value."""
91-
from seclab_taskflow_agent.__main__ import parse_prompt_args
92-
available_tools = AvailableTools()
93-
94-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
95-
available_tools, "-t example -g equation=x=5")
96-
97-
assert t == "example"
98-
assert cli_globals == {"equation": "x=5"}
99-
100-
def test_globals_in_taskflow_file(self):
101-
"""Test that globals can be read from taskflow file."""
102-
available_tools = AvailableTools()
103-
104-
taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow")
105-
assert 'globals' in taskflow
106-
assert taskflow['globals']['test_var'] == 'default_value'
107-
108-
class TestAPIEndpoint:
109-
"""Test API endpoint configuration."""
110-
111-
@staticmethod
112-
def _reload_capi_module():
113-
"""Helper method to reload the capi module."""
114-
import importlib
115-
import seclab_taskflow_agent.capi
116-
importlib.reload(seclab_taskflow_agent.capi)
117-
118-
def test_default_api_endpoint(self):
119-
"""Test that default API endpoint is set to models.github.ai/inference."""
120-
from seclab_taskflow_agent.capi import AI_API_ENDPOINT
121-
# When no env var is set, it should default to models.github.ai/inference
122-
# Note: We can't easily test this without manipulating the environment
123-
# so we'll just import and verify the constant exists
124-
assert AI_API_ENDPOINT is not None
125-
assert isinstance(AI_API_ENDPOINT, str)
126-
127-
def test_api_endpoint_env_override(self):
128-
"""Test that AI_API_ENDPOINT can be overridden by environment variable."""
129-
# Save original env
130-
original_env = os.environ.get('AI_API_ENDPOINT')
131-
132-
try:
133-
# Set custom endpoint
134-
test_endpoint = 'https://test.example.com'
135-
os.environ['AI_API_ENDPOINT'] = test_endpoint
136-
137-
# Reload the module to pick up the new env var
138-
self._reload_capi_module()
139-
140-
from seclab_taskflow_agent.capi import AI_API_ENDPOINT
141-
assert AI_API_ENDPOINT == test_endpoint
142-
finally:
143-
# Restore original env
144-
if original_env is None:
145-
os.environ.pop('AI_API_ENDPOINT', None)
146-
else:
147-
os.environ['AI_API_ENDPOINT'] = original_env
148-
# Reload again to restore original state
149-
self._reload_capi_module()
150-
15148
if __name__ == '__main__':
15249
pytest.main([__file__, '-v'])

0 commit comments

Comments
 (0)