Skip to content

Commit 48125bd

Browse files
committed
feat(add executor):
1 parent 5285fdd commit 48125bd

File tree

14 files changed

+1710
-2
lines changed

14 files changed

+1710
-2
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ agentic_security.toml
2020
/venv
2121
*.csv
2222
agentic_security/agents/operator_agno.py
23+
.claude/
24+
plan.md
25+
auto_loop.sh
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Advanced concurrent execution package for security scanning."""
2+
3+
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
4+
from agentic_security.executor.circuit_breaker import CircuitBreaker
5+
from agentic_security.executor.concurrent import ConcurrentExecutor, ExecutorMetrics
6+
7+
__all__ = [
8+
"TokenBucketRateLimiter",
9+
"CircuitBreaker",
10+
"ConcurrentExecutor",
11+
"ExecutorMetrics",
12+
]
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Circuit breaker pattern for fault tolerance."""
2+
3+
import time
4+
from typing import Literal
5+
6+
7+
CircuitState = Literal["closed", "open", "half_open"]
8+
9+
10+
class CircuitBreaker:
11+
"""Circuit breaker to prevent cascading failures.
12+
13+
Implements the circuit breaker pattern with three states:
14+
- closed: Normal operation, requests pass through
15+
- open: Failure threshold exceeded, requests fail fast
16+
- half_open: Recovery attempt, limited requests allowed
17+
18+
Example:
19+
>>> breaker = CircuitBreaker(failure_threshold=0.5, recovery_timeout=30)
20+
>>> if breaker.is_open():
21+
... raise Exception("Circuit breaker is open")
22+
>>> try:
23+
... result = make_request()
24+
... breaker.record_success()
25+
>>> except Exception:
26+
... breaker.record_failure()
27+
"""
28+
29+
def __init__(self, failure_threshold: float = 0.5, recovery_timeout: int = 30):
30+
"""Initialize circuit breaker.
31+
32+
Args:
33+
failure_threshold: Failure rate (0.0-1.0) that triggers open state
34+
recovery_timeout: Seconds to wait before attempting recovery
35+
"""
36+
self.failure_threshold = failure_threshold
37+
self.recovery_timeout = recovery_timeout
38+
self.failures = 0
39+
self.successes = 0
40+
self.state: CircuitState = "closed"
41+
self.last_failure_time: float | None = None
42+
43+
def record_success(self):
44+
"""Record a successful request."""
45+
self.successes += 1
46+
47+
# If in half_open state and we have enough successes, close the circuit
48+
if self.state == "half_open" and self.successes >= 3:
49+
self.state = "closed"
50+
self.failures = 0
51+
self.successes = 0
52+
53+
def record_failure(self):
54+
"""Record a failed request."""
55+
self.failures += 1
56+
self.last_failure_time = time.monotonic()
57+
58+
total = self.failures + self.successes
59+
60+
# Need minimum sample size before opening circuit
61+
if total >= 10:
62+
failure_rate = self.failures / total
63+
if failure_rate >= self.failure_threshold:
64+
self.state = "open"
65+
66+
def is_open(self) -> bool:
67+
"""Check if circuit breaker is open.
68+
69+
Returns:
70+
bool: True if circuit is open and requests should be blocked
71+
"""
72+
if self.state == "open":
73+
# Check if we should attempt recovery
74+
if self.last_failure_time is not None:
75+
if time.monotonic() - self.last_failure_time > self.recovery_timeout:
76+
self.state = "half_open"
77+
# Reset counters for half-open state
78+
self.failures = 0
79+
self.successes = 0
80+
return False
81+
return True
82+
83+
return False
84+
85+
def get_state(self) -> CircuitState:
86+
"""Get current circuit breaker state.
87+
88+
Returns:
89+
CircuitState: Current state (closed, open, or half_open)
90+
"""
91+
return self.state
92+
93+
def get_failure_rate(self) -> float:
94+
"""Get current failure rate.
95+
96+
Returns:
97+
float: Failure rate (0.0-1.0), or 0.0 if no requests recorded
98+
"""
99+
total = self.failures + self.successes
100+
if total == 0:
101+
return 0.0
102+
return self.failures / total
103+
104+
def reset(self):
105+
"""Reset circuit breaker to initial state."""
106+
self.failures = 0
107+
self.successes = 0
108+
self.state = "closed"
109+
self.last_failure_time = None
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""Concurrent executor with rate limiting and circuit breaking."""
2+
3+
import asyncio
4+
import time
5+
from typing import Any
6+
7+
from agentic_security.executor.rate_limiter import TokenBucketRateLimiter
8+
from agentic_security.executor.circuit_breaker import CircuitBreaker
9+
from agentic_security.logutils import logger
10+
from agentic_security.probe_actor.state import FuzzerState
11+
12+
13+
class ExecutorMetrics:
14+
"""Track executor performance metrics."""
15+
16+
def __init__(self):
17+
"""Initialize metrics tracking."""
18+
self.successful_requests = 0
19+
self.failed_requests = 0
20+
self.total_latency = 0.0
21+
self.latencies: list[float] = []
22+
23+
def record_success(self, latency: float):
24+
"""Record a successful request.
25+
26+
Args:
27+
latency: Request latency in seconds
28+
"""
29+
self.successful_requests += 1
30+
self.total_latency += latency
31+
self.latencies.append(latency)
32+
33+
def record_failure(self):
34+
"""Record a failed request."""
35+
self.failed_requests += 1
36+
37+
def get_stats(self) -> dict[str, Any]:
38+
"""Get current statistics.
39+
40+
Returns:
41+
dict: Statistics including total requests, success rate, latency metrics
42+
"""
43+
total_requests = self.successful_requests + self.failed_requests
44+
45+
if total_requests == 0:
46+
return {
47+
"total_requests": 0,
48+
"success_rate": 0.0,
49+
"avg_latency_ms": 0.0,
50+
"p95_latency_ms": 0.0,
51+
}
52+
53+
success_rate = self.successful_requests / total_requests
54+
avg_latency_ms = (
55+
(self.total_latency / self.successful_requests * 1000)
56+
if self.successful_requests > 0
57+
else 0.0
58+
)
59+
60+
# Calculate p95 latency
61+
if self.latencies:
62+
sorted_latencies = sorted(self.latencies)
63+
p95_index = int(len(sorted_latencies) * 0.95)
64+
p95_latency_ms = sorted_latencies[p95_index] * 1000 if p95_index < len(sorted_latencies) else 0.0
65+
else:
66+
p95_latency_ms = 0.0
67+
68+
return {
69+
"total_requests": total_requests,
70+
"successful_requests": self.successful_requests,
71+
"failed_requests": self.failed_requests,
72+
"success_rate": success_rate,
73+
"avg_latency_ms": avg_latency_ms,
74+
"p95_latency_ms": p95_latency_ms,
75+
}
76+
77+
78+
class ConcurrentExecutor:
79+
"""Enhanced concurrent executor with rate limiting and circuit breaking.
80+
81+
Provides advanced concurrency control for security scanning with:
82+
- Token bucket rate limiting
83+
- Circuit breaker for fault tolerance
84+
- Metrics collection
85+
- Semaphore-based concurrency limits
86+
87+
Example:
88+
>>> executor = ConcurrentExecutor(max_concurrent=20, rate_limit=10, burst=5)
89+
>>> tokens, failures = await executor.execute_batch(
90+
... request_factory, prompts, "module_name", fuzzer_state
91+
... )
92+
>>> print(executor.metrics.get_stats())
93+
"""
94+
95+
def __init__(
96+
self,
97+
max_concurrent: int = 50,
98+
rate_limit: float = 100,
99+
burst: int = 20,
100+
failure_threshold: float = 0.5,
101+
recovery_timeout: int = 30,
102+
):
103+
"""Initialize concurrent executor.
104+
105+
Args:
106+
max_concurrent: Maximum number of concurrent requests
107+
rate_limit: Requests per second limit
108+
burst: Maximum burst size for rate limiter
109+
failure_threshold: Failure rate that triggers circuit breaker
110+
recovery_timeout: Seconds before attempting circuit recovery
111+
"""
112+
self.semaphore = asyncio.Semaphore(max_concurrent)
113+
self.rate_limiter = TokenBucketRateLimiter(rate_limit, burst)
114+
self.circuit_breaker = CircuitBreaker(failure_threshold, recovery_timeout)
115+
self.metrics = ExecutorMetrics()
116+
117+
logger.info(
118+
f"ConcurrentExecutor initialized: max_concurrent={max_concurrent}, "
119+
f"rate_limit={rate_limit}/s, burst={burst}"
120+
)
121+
122+
async def execute_batch(
123+
self,
124+
request_factory,
125+
prompts: list[str],
126+
module_name: str,
127+
fuzzer_state: FuzzerState,
128+
) -> tuple[int, int]:
129+
"""Execute a batch of prompts with rate limiting and circuit breaking.
130+
131+
This is compatible with the existing process_prompt_batch signature.
132+
133+
Args:
134+
request_factory: Request factory with fn() method
135+
prompts: List of prompts to process
136+
module_name: Name of the module being scanned
137+
fuzzer_state: State tracking object
138+
139+
Returns:
140+
tuple[int, int]: (total_tokens, failures)
141+
"""
142+
tasks = [
143+
self._execute_single(request_factory, prompt, module_name, fuzzer_state)
144+
for prompt in prompts
145+
]
146+
147+
results = await asyncio.gather(*tasks, return_exceptions=True)
148+
149+
# Aggregate results
150+
total_tokens = 0
151+
failures = 0
152+
153+
for result in results:
154+
if isinstance(result, Exception):
155+
failures += 1
156+
logger.error(f"Task failed with exception: {result}")
157+
else:
158+
tokens, refused = result
159+
total_tokens += tokens
160+
if refused:
161+
failures += 1
162+
163+
return total_tokens, failures
164+
165+
async def _execute_single(
166+
self,
167+
request_factory,
168+
prompt: str,
169+
module_name: str,
170+
fuzzer_state: FuzzerState,
171+
) -> tuple[int, bool]:
172+
"""Execute a single prompt with rate limiting and circuit breaking.
173+
174+
Args:
175+
request_factory: Request factory with fn() method
176+
prompt: Prompt to process
177+
module_name: Name of the module being scanned
178+
fuzzer_state: State tracking object
179+
180+
Returns:
181+
tuple[int, bool]: (tokens, refused)
182+
183+
Raises:
184+
Exception: If circuit breaker is open
185+
"""
186+
# Rate limiting
187+
await self.rate_limiter.acquire()
188+
189+
# Circuit breaker check
190+
if self.circuit_breaker.is_open():
191+
self.metrics.record_failure()
192+
raise Exception("Circuit breaker is open - too many failures")
193+
194+
# Concurrency control
195+
async with self.semaphore:
196+
start_time = time.monotonic()
197+
198+
try:
199+
# Import here to avoid circular dependency
200+
from agentic_security.probe_actor.fuzzer import process_prompt
201+
202+
tokens = 0 # Initial token count for this prompt
203+
result = await process_prompt(
204+
request_factory, prompt, tokens, module_name, fuzzer_state
205+
)
206+
207+
# Record success
208+
self.circuit_breaker.record_success()
209+
latency = time.monotonic() - start_time
210+
self.metrics.record_success(latency)
211+
212+
return result
213+
214+
except Exception as e:
215+
# Record failure
216+
self.circuit_breaker.record_failure()
217+
self.metrics.record_failure()
218+
logger.error(f"Error executing prompt: {e}")
219+
raise
220+
221+
def get_metrics(self) -> dict[str, Any]:
222+
"""Get current executor metrics.
223+
224+
Returns:
225+
dict: Metrics including request stats, latency, and circuit breaker state
226+
"""
227+
stats = self.metrics.get_stats()
228+
stats["circuit_breaker_state"] = self.circuit_breaker.get_state()
229+
stats["circuit_breaker_failure_rate"] = self.circuit_breaker.get_failure_rate()
230+
stats["available_tokens"] = self.rate_limiter.get_available_tokens()
231+
232+
return stats

0 commit comments

Comments
 (0)