Skip to content

Commit 71787c6

Browse files
committed
Add type annotations to functions and methods for improved clarity and maintainabiliy
1 parent 21180b5 commit 71787c6

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

agentic_security/core/app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
from fastapi import FastAPI
55
from fastapi.responses import ORJSONResponse
66

7+
from agentic_security.http_spec import LLMSpec
8+
from typing import Any, Dict
9+
710
tools_inbox: Queue = Queue()
811
stop_event: Event = Event()
912
current_run: str = {"spec": "", "id": ""}
10-
_secrets = {}
13+
_secrets: dict[str, str] = {}
1114

15+
current_run: Dict[str, int | LLMSpec] = {
16+
"spec": "",
17+
"id": ""
18+
}
1219

1320
def create_app() -> FastAPI:
1421
"""Create and configure the FastAPI application."""
@@ -26,29 +33,29 @@ def get_stop_event() -> Event:
2633
return stop_event
2734

2835

29-
def get_current_run() -> str:
36+
def get_current_run() -> Dict[str, int | LLMSpec]:
3037
"""Get the current run id."""
3138
return current_run
3239

3340

34-
def set_current_run(spec):
41+
def set_current_run(spec : LLMSpec) -> Dict[str, int | LLMSpec]:
3542
"""Set the current run id."""
3643
current_run["id"] = hash(id(spec))
3744
current_run["spec"] = spec
3845
return current_run
3946

4047

41-
def get_secrets():
48+
def get_secrets() -> dict[str, str]:
4249
return _secrets
4350

4451

45-
def set_secrets(secrets):
52+
def set_secrets(secrets : dict[str, str]) -> dict[str, str]:
4653
_secrets.update(secrets)
4754
expand_secrets(_secrets)
4855
return _secrets
4956

5057

51-
def expand_secrets(secrets):
58+
def expand_secrets(secrets : dict[str, str]) -> None:
5259
for key in secrets:
5360
val = secrets[key]
5461
if val.startswith("$"):

agentic_security/misc/banner.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pyfiglet import Figlet, FontNotFound
22
from termcolor import colored
3+
from typing import Optional
34

45
try:
56
from importlib.metadata import version
@@ -8,14 +9,14 @@
89

910

1011
def generate_banner(
11-
title="Agentic Security",
12-
font="slant",
13-
version="v2.1.0",
14-
tagline="Proactive Threat Detection & Automated Security Protocols",
15-
author="Developed by: [Security Team]",
16-
website="Website: https://github.com/msoedov/agentic_security",
17-
warning="",
18-
):
12+
title: str = "Agentic Security",
13+
font: str = "slant",
14+
version: str = "v2.1.0",
15+
tagline: str = "Proactive Threat Detection & Automated Security Protocols",
16+
author: str = "Developed by: [Security Team]",
17+
website: str = "Website: https://github.com/msoedov/agentic_security",
18+
warning: Optional[str] = "", # Using Optional for warning since it might be None
19+
) -> str:
1920
"""Generate a visually enhanced banner with dynamic width and borders."""
2021
# Define the text elements
2122

agentic_security/report_chart.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import io
22
import string
3+
from typing import List
34

45
import matplotlib.pyplot as plt
56
import numpy as np
67
import pandas as pd
78
from matplotlib.cm import ScalarMappable
89
from matplotlib.colors import LinearSegmentedColormap, Normalize
910

11+
from .primitives import Table
1012

11-
def plot_security_report(table):
13+
def plot_security_report(table: Table) -> io.BytesIO:
1214
# Data preprocessing
1315
data = pd.DataFrame(table)
1416

@@ -141,7 +143,7 @@ def plot_security_report(table):
141143
return buf
142144

143145

144-
def generate_identifiers(data):
146+
def generate_identifiers(data : pd.DataFrame) -> List[str]:
145147
data_length = len(data)
146148
alphabet = string.ascii_uppercase
147149
num_letters = len(alphabet)

agentic_security/routes/scan.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime
2+
from typing import Any, Generator
23

34
from fastapi import (
45
APIRouter,
@@ -24,7 +25,7 @@
2425
@router.post("/verify")
2526
async def verify(
2627
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
27-
):
28+
) -> dict[str, int | str | float]:
2829
spec = LLMSpec.from_string(info.spec)
2930
try:
3031
r = await spec.verify()
@@ -42,7 +43,7 @@ async def verify(
4243
)
4344

4445

45-
def streaming_response_generator(scan_parameters: Scan):
46+
def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]:
4647
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
4748
set_current_run(request_factory)
4849

@@ -63,15 +64,15 @@ async def scan(
6364
scan_parameters: Scan,
6465
background_tasks: BackgroundTasks,
6566
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
66-
):
67+
) -> StreamingResponse:
6768
scan_parameters.with_secrets(secrets)
6869
return StreamingResponse(
6970
streaming_response_generator(scan_parameters), media_type="application/json"
7071
)
7172

7273

7374
@router.post("/stop")
74-
async def stop_scan():
75+
async def stop_scan() -> dict[str, str]:
7576
get_stop_event().set()
7677
return {"status": "Scan stopped"}
7778

@@ -85,7 +86,7 @@ async def scan_csv(
8586
maxBudget: int = Query(10_000),
8687
enableMultiStepAttack: bool = Query(False),
8788
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
88-
):
89+
) -> StreamingResponse:
8990
# TODO: content dataset to fuzzer
9091
content = await file.read() # noqa
9192
llm_spec = await llmSpec.read()

0 commit comments

Comments
 (0)