
Commit 03dcf8c

feat(Update app structure):
1 parent 65edfe8 commit 03dcf8c

File tree: 14 files changed, +357 −284 lines


agentic_security/app.py

Lines changed: 23 additions & 282 deletions
@@ -1,287 +1,28 @@
-import os
-import random
-from asyncio import Event, Queue
-from datetime import datetime
-from logging import config
-from pathlib import Path
-
-from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, Response
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse, StreamingResponse
-from loguru import logger
-from pydantic import BaseModel, Field
-from starlette.middleware.base import BaseHTTPMiddleware
-
-from .http_spec import LLMSpec
-from .probe_actor import fuzzer
-from .probe_actor.refusal import REFUSAL_MARKS
-from .probe_data import REGISTRY
-from .report_chart import plot_security_report
-
-# Create the FastAPI app instance
-app = FastAPI()
-origins = [
-    "*",
-]
-
-
-# Configuration
-class Settings:
-    MAX_BUDGET = 1000
-    MAX_DATASETS = 10
-    RATE_LIMIT = "100/minute"
-    DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False)
-    FEATURE_PROXY = False
-
-
-settings = Settings()
-
-# Middleware setup
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],  # Allows all methods
-    allow_headers=["*"],  # Allows all headers
+from .core.app import create_app
+from .core.logging import setup_logging
+from .middleware.cors import setup_cors
+from .middleware.logging import LogNon200ResponsesMiddleware
+from .routes import (
+    static_router,
+    scan_router,
+    probe_router,
+    proxy_router,
+    report_router,
 )
 
-tools_inbox = Queue()
-# Global stop event for cancelling scans
-stop_event = Event()  # Added stop_event to cancel the scan
-
-
-@app.get("/")
-async def root():
-    agentic_security_path = Path(__file__).parent
-    return FileResponse(f"{agentic_security_path}/static/index.html")
-
-
-@app.get("/main.js")
-async def main_js():
-    agentic_security_path = Path(__file__).parent
-    return FileResponse(f"{agentic_security_path}/static/main.js")
-
-
-@app.get("/telemetry.js")
-async def telemetry_js():
-    agentic_security_path = Path(__file__).parent
-    if settings.DISABLE_TELEMETRY:
-        return FileResponse(f"{agentic_security_path}/static/telemetry_disabled.js")
-    return FileResponse(f"{agentic_security_path}/static/telemetry.js")
-
-
-@app.get("/favicon.ico")
-async def favicon():
-    agentic_security_path = Path(__file__).parent
-    return FileResponse(f"{agentic_security_path}/static/favicon.ico")
-
-
-class LLMInfo(BaseModel):
-    spec: str
-
-
-@app.post("/verify")
-async def verify(info: LLMInfo):
-
-    spec = LLMSpec.from_string(info.spec)
-    r = await spec.probe("test")
-    if r.status_code >= 400:
-        raise HTTPException(status_code=r.status_code, detail=r.text)
-    return dict(
-        status_code=r.status_code,
-        body=r.text,
-        elapsed=r.elapsed.total_seconds(),
-        timestamp=datetime.now().isoformat(),
-    )
-
-
-class Scan(BaseModel):
-    llmSpec: str
-    maxBudget: int
-    datasets: list[dict] = []
-    optimize: bool = False
-
-
-class ScanResult(BaseModel):
-    module: str
-    tokens: int
-    cost: float
-    progress: float
-    failureRate: float = 0.0
-
-
-def streaming_response_generator(scan_parameters: Scan):
-    # The generator function for StreamingResponse
-    request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
-
-    async def _gen():
-        async for scan_result in fuzzer.perform_scan(
-            request_factory=request_factory,
-            max_budget=scan_parameters.maxBudget,
-            datasets=scan_parameters.datasets,
-            tools_inbox=tools_inbox,
-            optimize=scan_parameters.optimize,
-            stop_event=stop_event,  # Pass the stop_event to the generator
-        ):
-            yield scan_result + "\n"  # Adding a newline for separation
-
-    return _gen()
-
+# Create the FastAPI app
+app = create_app()
 
-@app.post("/scan")
-async def scan(scan_parameters: Scan, background_tasks: BackgroundTasks):
-
-    # Initiates streaming of scan results
-    return StreamingResponse(
-        streaming_response_generator(scan_parameters), media_type="application/json"
-    )
-
-
-class Probe(BaseModel):
-    prompt: str
-
-
-@app.post("/v1/self-probe")
-def self_probe(probe: Probe):
-    refuse = random.random() < 0.2
-    message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!"
-    message = probe.prompt + " " + message
-    return {
-        "id": "chatcmpl-abc123",
-        "object": "chat.completion",
-        "created": 1677858242,
-        "model": "gpt-3.5-turbo-0613",
-        "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
-        "choices": [
-            {
-                "message": {"role": "assistant", "content": message},
-                "logprobs": None,
-                "finish_reason": "stop",
-                "index": 0,
-            }
-        ],
-    }
-
-
-@app.get("/v1/data-config")
-async def data_config():
-    return [m for m in REGISTRY]
-
-
-@app.get("/failures")
-async def failures_csv():
-    if not Path("failures.csv").exists():
-        return {"error": "No failures found"}
-    return FileResponse("failures.csv")
-
-
-class Table(BaseModel):
-    table: list[dict]
-
-
-@app.post("/plot.jpeg", response_class=Response)
-async def get_plot(table: Table):
-    buf = plot_security_report(table.table)
-    return StreamingResponse(buf, media_type="image/jpeg")
-
-
-class Message(BaseModel):
-    role: str
-    content: str
-
-
-class CompletionRequest(BaseModel):
-    """Model for completion requests."""
-
-    model: str
-    messages: list[Message]
-    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
-    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
-    n: int = Field(default=1, ge=1, le=10)
-    stop: list[str] | None = None
-    max_tokens: int = Field(default=100, ge=1, le=4096)
-    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
-    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
-
-
-# OpenAI proxy endpoint
-@app.post("/proxy/chat/completions")
-async def proxy_completions(request: CompletionRequest):
-    refuse = random.random() < 0.2
-    message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!"
-    prompt_content = " ".join(
-        [msg.content for msg in request.messages if msg.role == "user"]
-    )
-    message = prompt_content + " " + message
-    ready = Event()
-    ref = dict(message=message, reply="", ready=ready)
-    tools_inbox.put_nowait(ref)
-    if settings.FEATURE_PROXY:
-        # Proxy to agent
-        await ready.wait()
-        reply = ref["reply"]
-        return reply
-    # Simulate a completion response
-    return {
-        "id": "chatcmpl-abc123",
-        "object": "chat.completion",
-        "created": 1677858242,
-        "model": "gpt-3.5-turbo-0613",
-        "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
-        "choices": [
-            {
-                "message": {"role": "assistant", "content": message},
-                "logprobs": None,
-                "finish_reason": "stop",
-                "index": 0,
-            }
-        ],
-    }
-
-
-config.dictConfig(
-    {
-        "version": 1,
-        "disable_existing_loggers": True,
-        "handlers": {
-            "console": {
-                "class": "logging.StreamHandler",
-            },
-        },
-        "root": {
-            "handlers": ["console"],
-            "level": "INFO",
-        },
-        "loggers": {
-            "uvicorn.access": {
-                "level": "ERROR",  # Set higher log level to suppress info logs globally
-                "handlers": ["console"],
-                "propagate": False,
-            }
-        },
-    }
-)
-
-
-@app.post("/stop")
-async def stop_scan():
-    stop_event.set()  # Set the stop event to cancel the scan
-    return {"status": "Scan stopped"}
-
-
-class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        try:
-            response = await call_next(request)
-        except Exception as e:
-            logger.exception("Yikes")
-            raise e
-        if response.status_code != 200:
-            logger.error(
-                f"{request.method} {request.url} - Status code: {response.status_code}"
-            )
-        return response
+# Setup middleware
+setup_cors(app)
+app.add_middleware(LogNon200ResponsesMiddleware)
 
+# Setup logging
+setup_logging()
 
-# Add middleware to the application
-app.add_middleware(LogNon200ResponsesMiddleware)
+# Register routers
+app.include_router(static_router)
+app.include_router(scan_router)
+app.include_router(probe_router)
+app.include_router(proxy_router)
+app.include_router(report_router)
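
The restructured app.py now only wires the pieces together. A minimal sketch of how the assembled application could be served; the entry point, host, and port below are assumptions for illustration, not part of this commit:

# Hypothetical entry point (assumed): serve the assembled app with uvicorn.
import uvicorn

from agentic_security.app import app  # app is built by create_app() at import time

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)  # port chosen for illustration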

agentic_security/core/app.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from asyncio import Event, Queue
+from fastapi import FastAPI
+
+tools_inbox: Queue = Queue()
+stop_event: Event = Event()
+
+
+def create_app() -> FastAPI:
+    """Create and configure the FastAPI application."""
+    app = FastAPI()
+    return app
+
+
+def get_tools_inbox() -> Queue:
+    """Get the global tools inbox queue."""
+    return tools_inbox
+
+
+def get_stop_event() -> Event:
+    """Get the global stop event."""
+    return stop_event
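
core/app.py keeps the shared Queue and Event as module-level singletons behind accessor functions. A minimal sketch of how a route module might consume them, mirroring the old /stop endpoint; the router layout is illustrative, since the actual routes/*.py files are not shown in this diff:

# Illustrative only: a router reaching shared state through the accessors.
from fastapi import APIRouter

from agentic_security.core.app import get_stop_event

router = APIRouter()


@router.post("/stop")
async def stop_scan():
    get_stop_event().set()  # same Event the scan generator checks for cancellation
    return {"status": "Scan stopped"}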

agentic_security/core/logging.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from logging import config
+
+
+def setup_logging():
+    config.dictConfig(
+        {
+            "version": 1,
+            "disable_existing_loggers": True,
+            "handlers": {
+                "console": {
+                    "class": "logging.StreamHandler",
+                },
+            },
+            "root": {
+                "handlers": ["console"],
+                "level": "INFO",
+            },
+            "loggers": {
+                "uvicorn.access": {
+                    "level": "ERROR",  # Set higher log level to suppress info logs globally
+                    "handlers": ["console"],
+                    "propagate": False,
+                }
+            },
+        }
+    )
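
A quick way to check the effect of this configuration (assumed usage, not part of the commit): after setup_logging(), the root logger still emits at INFO, while uvicorn.access only passes ERROR and above.

import logging

from agentic_security.core.logging import setup_logging

setup_logging()
logging.getLogger().info("visible: root logger is at INFO")
logging.getLogger("uvicorn.access").info("hidden: below the ERROR threshold")
logging.getLogger("uvicorn.access").error("visible: ERROR still reaches the console handler")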

agentic_security/lib.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@
 import tqdm.asyncio
 from tabulate import tabulate
 
-from agentic_security.app import Scan, streaming_response_generator
+from agentic_security.models.schemas import Scan
+from agentic_security.routes.scan import streaming_response_generator
 from agentic_security.probe_data import REGISTRY
 
 RESET = colorama.Style.RESET_ALL
agentic_security/middleware/cors.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+
+def setup_cors(app: FastAPI):
+    origins = ["*"]
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=origins,
+        allow_credentials=True,
+        allow_methods=["*"],  # Allows all methods
+        allow_headers=["*"],  # Allows all headers
+    )
