diff --git a/agentic_security/core/app.py b/agentic_security/core/app.py index 3938ba8..3cb306e 100644 --- a/agentic_security/core/app.py +++ b/agentic_security/core/app.py @@ -4,10 +4,14 @@ from fastapi import FastAPI from fastapi.responses import ORJSONResponse +from agentic_security.http_spec import LLMSpec + tools_inbox: Queue = Queue() stop_event: Event = Event() current_run: str = {"spec": "", "id": ""} -_secrets = {} +_secrets: dict[str, str] = {} + +current_run: dict[str, int | LLMSpec] = {"spec": "", "id": ""} def create_app() -> FastAPI: @@ -26,29 +30,29 @@ def get_stop_event() -> Event: return stop_event -def get_current_run() -> str: +def get_current_run() -> dict[str, int | LLMSpec]: """Get the current run id.""" return current_run -def set_current_run(spec): +def set_current_run(spec: LLMSpec) -> dict[str, int | LLMSpec]: """Set the current run id.""" current_run["id"] = hash(id(spec)) current_run["spec"] = spec return current_run -def get_secrets(): +def get_secrets() -> dict[str, str]: return _secrets -def set_secrets(secrets): +def set_secrets(secrets: dict[str, str]) -> dict[str, str]: _secrets.update(secrets) expand_secrets(_secrets) return _secrets -def expand_secrets(secrets): +def expand_secrets(secrets: dict[str, str]) -> None: for key in secrets: val = secrets[key] if val.startswith("$"): diff --git a/agentic_security/misc/banner.py b/agentic_security/misc/banner.py index f758c2f..45edba5 100644 --- a/agentic_security/misc/banner.py +++ b/agentic_security/misc/banner.py @@ -1,3 +1,4 @@ + from pyfiglet import Figlet, FontNotFound from termcolor import colored @@ -8,14 +9,14 @@ def generate_banner( - title="Agentic Security", - font="slant", - version="v2.1.0", - tagline="Proactive Threat Detection & Automated Security Protocols", - author="Developed by: [Security Team]", - website="Website: https://github.com/msoedov/agentic_security", - warning="", -): + title: str = "Agentic Security", + font: str = "slant", + version: str = "v2.1.0", + tagline: str = "Proactive Threat Detection & Automated Security Protocols", + author: str = "Developed by: [Security Team]", + website: str = "Website: https://github.com/msoedov/agentic_security", + warning: str | None = "", # Using Optional for warning since it might be None +) -> str: """Generate a visually enhanced banner with dynamic width and borders.""" # Define the text elements diff --git a/agentic_security/report_chart.py b/agentic_security/report_chart.py index c76e576..45387cb 100644 --- a/agentic_security/report_chart.py +++ b/agentic_security/report_chart.py @@ -7,8 +7,10 @@ from matplotlib.cm import ScalarMappable from matplotlib.colors import LinearSegmentedColormap, Normalize +from .primitives import Table -def plot_security_report(table): + +def plot_security_report(table: Table) -> io.BytesIO: # Data preprocessing data = pd.DataFrame(table) @@ -141,7 +143,7 @@ def plot_security_report(table): return buf -def generate_identifiers(data): +def generate_identifiers(data: pd.DataFrame) -> list[str]: data_length = len(data) alphabet = string.ascii_uppercase num_letters = len(alphabet) diff --git a/agentic_security/routes/scan.py b/agentic_security/routes/scan.py index f5e2977..8313702 100644 --- a/agentic_security/routes/scan.py +++ b/agentic_security/routes/scan.py @@ -1,4 +1,6 @@ +from collections.abc import Generator from datetime import datetime +from typing import Any from fastapi import ( APIRouter, @@ -24,7 +26,7 @@ @router.post("/verify") async def verify( info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets) -): +) -> dict[str, int | str | float]: spec = LLMSpec.from_string(info.spec) try: r = await spec.verify() @@ -42,7 +44,7 @@ async def verify( ) -def streaming_response_generator(scan_parameters: Scan): +def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]: request_factory = LLMSpec.from_string(scan_parameters.llmSpec) set_current_run(request_factory) @@ -63,7 +65,7 @@ async def scan( scan_parameters: Scan, background_tasks: BackgroundTasks, secrets: InMemorySecrets = Depends(get_in_memory_secrets), -): +) -> StreamingResponse: scan_parameters.with_secrets(secrets) return StreamingResponse( streaming_response_generator(scan_parameters), media_type="application/json" @@ -71,7 +73,7 @@ async def scan( @router.post("/stop") -async def stop_scan(): +async def stop_scan() -> dict[str, str]: get_stop_event().set() return {"status": "Scan stopped"} @@ -85,7 +87,7 @@ async def scan_csv( maxBudget: int = Query(10_000), enableMultiStepAttack: bool = Query(False), secrets: InMemorySecrets = Depends(get_in_memory_secrets), -): +) -> StreamingResponse: # TODO: content dataset to fuzzer content = await file.read() # noqa llm_spec = await llmSpec.read()