-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsecurity_manager.py
More file actions
439 lines (351 loc) · 15.5 KB
/
security_manager.py
File metadata and controls
439 lines (351 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
import json
import re
import time
from datetime import datetime
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Optional

import jwt
import redis.asyncio as redis
import structlog
from cryptography.fernet import Fernet
logger = structlog.get_logger()
class SecurityManager:
    """Implements comprehensive security controls for MCP server.

    Covers JWT authentication, Fernet encryption of data at rest, audit
    logging, rate limiting, role-based permission checks, and input/query
    sanitization for Splunk.
    """

    # Audience this service expects in JWT "aud" claims.
    EXPECTED_AUDIENCE = "threat_hunting_mcp"

    # Substrings stripped from string input to blunt SPL/command injection.
    _DANGEROUS_TOKENS = [";", "|eval", "|delete", "|drop", "rm -rf"]

    # Hard cap on the length of a single string input.
    _MAX_INPUT_LENGTH = 10000

    def __init__(self, config: Dict):
        """Builds the manager from a config dict.

        Requires config["jwt_secret"]; "encryption_key", "audit_config"
        and "redis" entries are optional.
        """
        self.encryption_key = self._get_or_generate_key(config.get("encryption_key"))
        self.cipher = Fernet(self.encryption_key)
        self.jwt_secret = config["jwt_secret"]  # symmetric HMAC secret
        self.audit_logger = AuditLogger(config.get("audit_config", {}))
        self.rate_limiter = RateLimiter(config.get("redis", {}))
        self.current_user = None  # payload of the last authenticated token

    def authenticate_request(self, token: str) -> Dict:
        """Validates a JWT and stores its payload as the current user.

        Returns the decoded payload; raises ValueError on any failure.
        """
        try:
            # SECURITY FIX: accept HS256 only.  self.jwt_secret is a
            # symmetric key, and advertising RS256 alongside HS256 enables
            # the classic JWT algorithm-confusion attack.
            # FIX: "aud" is a required claim, so an explicit audience must
            # be supplied or PyJWT rejects every token carrying "aud".
            payload = jwt.decode(
                token,
                self.jwt_secret,
                algorithms=["HS256"],
                audience=self.EXPECTED_AUDIENCE,
                options={"require": ["exp", "iat", "sub", "aud"]},
            )
            # jwt.decode already enforces "exp"; no manual re-check needed.
            if not self._validate_token_claims(payload):
                raise ValueError("Invalid token claims")
            # Optional token binding (mitigates token theft/replay).
            if "token_binding" in payload:
                if not self._validate_token_binding(payload, token):
                    raise ValueError("Token binding mismatch")
            self.current_user = payload
            return payload
        except jwt.ExpiredSignatureError:
            raise ValueError("Token expired")
        except jwt.InvalidTokenError as e:
            raise ValueError(f"Invalid token: {str(e)}")

    def encrypt_sensitive_data(self, data: str) -> str:
        """Encrypts sensitive data at rest; non-strings are JSON-encoded first."""
        if not isinstance(data, str):
            data = json.dumps(data)
        return self.cipher.encrypt(data.encode()).decode()

    def decrypt_sensitive_data(self, encrypted_data: str) -> str:
        """Decrypts data produced by encrypt_sensitive_data.

        Raises ValueError (never the raw crypto error) on failure.
        """
        try:
            return self.cipher.decrypt(encrypted_data.encode()).decode()
        except Exception as e:
            logger.error("Decryption failed", error=str(e))
            raise ValueError("Failed to decrypt data")

    async def audit_log(self, action: str, user: str, details: Dict, severity: str = "INFO"):
        """Writes a compliance audit entry; WARNING+ also goes to the security log."""
        log_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "user": user,
            "action": action,
            "details": details,
            "session_id": self._get_session_id(),
            "ip_address": self._get_client_ip(),
            "user_agent": self._get_user_agent(),
            "severity": severity,
            "source": "threat_hunting_mcp",
        }
        await self.audit_logger.log(log_entry)
        # Security-relevant events are duplicated into the security log.
        if severity in ["WARNING", "ERROR", "CRITICAL"]:
            await self.audit_logger.security_log(log_entry)

    def rate_limit(self, key: str, max_requests: int = 100, window_seconds: int = 3600):
        """Decorator factory: sliding-window rate limit per user and key.

        Raises ValueError (and audits the event) when the limit is exceeded.
        """
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                user_key = f"{self.get_current_user()}:{key}"
                if not await self.rate_limiter.check_rate_limit(user_key, max_requests, window_seconds):
                    await self.audit_log(
                        "rate_limit_exceeded",
                        self.get_current_user(),
                        {"key": key, "limit": max_requests},
                        "WARNING",
                    )
                    raise ValueError("Rate limit exceeded")
                return await func(*args, **kwargs)
            return wrapper
        return decorator

    def require_auth(self, func):
        """Decorator: rejects calls made before authenticate_request succeeded."""
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if not self.current_user:
                raise ValueError("Authentication required")
            return await func(*args, **kwargs)
        return wrapper

    def require_permission(self, permission: str):
        """Decorator factory: requires a specific permission; denials are audited."""
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                if not self.has_permission(permission):
                    await self.audit_log(
                        "permission_denied",
                        self.get_current_user(),
                        {"permission": permission, "function": func.__name__},
                        "WARNING",
                    )
                    raise ValueError(f"Permission '{permission}' required")
                return await func(*args, **kwargs)
            return wrapper
        return decorator

    def has_permission(self, permission: str) -> bool:
        """True when the current user holds the permission directly or via a role."""
        if not self.current_user:
            return False
        user_permissions = self.current_user.get("permissions", [])
        user_roles = self.current_user.get("roles", [])
        # Direct grant first, then role-based grants.
        if permission in user_permissions:
            return True
        role_permissions = self._get_role_permissions()
        return any(permission in role_permissions.get(role, []) for role in user_roles)

    def get_current_user(self) -> str:
        """Returns the current user's subject claim, or "anonymous"."""
        if self.current_user:
            return self.current_user.get("sub", "unknown")
        return "anonymous"

    def sanitize_input(self, input_data: Any) -> Any:
        """Recursively strips dangerous tokens from strings and caps their length.

        Dicts and lists are sanitized element-wise; other types pass through.
        """
        if isinstance(input_data, str):
            for token in self._DANGEROUS_TOKENS:
                # FIX: the original detected tokens case-insensitively but
                # removed them case-sensitively, so e.g. "|EVAL" survived
                # sanitization.  Remove all case variants.
                pattern = re.compile(re.escape(token), re.IGNORECASE)
                if pattern.search(input_data):
                    logger.warning(
                        "Potentially dangerous input detected",
                        input=input_data,
                        char=token)
                    input_data = pattern.sub("", input_data)
            # Limit input length.
            if len(input_data) > self._MAX_INPUT_LENGTH:
                input_data = input_data[:self._MAX_INPUT_LENGTH]
        elif isinstance(input_data, dict):
            return {k: self.sanitize_input(v) for k, v in input_data.items()}
        elif isinstance(input_data, list):
            return [self.sanitize_input(item) for item in input_data]
        return input_data

    def validate_splunk_query(self, query: str) -> bool:
        """Rejects Splunk queries containing destructive/exfiltration commands
        or excessive wildcards; returns True when the query looks safe."""
        dangerous_commands = [
            "delete",
            "drop",
            "outputcsv",
            "outputlookup",
            "script",
            "sendemail",
            "rest"]
        query_lower = query.lower()
        for cmd in dangerous_commands:
            if cmd in query_lower:
                logger.warning("Dangerous Splunk command detected", query=query, command=cmd)
                return False
        # Over-broad wildcard queries are a performance and data-scope risk.
        if query.count("*") > 10:
            logger.warning("Excessive wildcards in query", query=query)
            return False
        return True

    def _get_or_generate_key(self, provided_key: Optional[str]) -> bytes:
        """Returns the configured Fernet key, or generates a fresh one.

        A provided key must already be urlsafe-base64 (Fernet format).
        NOTE: a generated key is not persisted, so data encrypted with it
        cannot be decrypted after a restart - store it securely in production.
        """
        if provided_key:
            return provided_key.encode()
        return Fernet.generate_key()

    def _validate_token_claims(self, payload: Dict) -> bool:
        """Checks required claims and, when present, the audience claim."""
        required_claims = ["sub", "iat", "exp"]
        for claim in required_claims:
            if claim not in payload:
                return False
        if "aud" in payload and payload["aud"] != self.EXPECTED_AUDIENCE:
            return False
        return True

    def _validate_token_binding(self, payload: Dict, token: str) -> bool:
        """Placeholder for token-binding validation (always True for now)."""
        return True

    def _get_session_id(self) -> str:
        """Session id from the current token payload, if authenticated."""
        if self.current_user:
            return self.current_user.get("session_id", "unknown")
        return "no_session"

    def _get_client_ip(self) -> str:
        """Client IP address - transport-specific, not yet implemented."""
        return "unknown"

    def _get_user_agent(self) -> str:
        """User agent string - transport-specific, not yet implemented."""
        return "mcp_client"

    def _get_role_permissions(self) -> Dict[str, List[str]]:
        """Static role -> permissions mapping used for RBAC checks."""
        return {
            "admin": ["hunt:create", "hunt:execute", "hunt:delete", "query:splunk", "config:modify", "user:manage"],
            "analyst": ["hunt:create", "hunt:execute", "query:splunk"],
            "viewer": ["hunt:view", "query:view"],
        }
class AuditLogger:
    """Handles audit logging with multiple outputs.

    Entries go to a main audit file as JSON lines; security-relevant
    entries go to a dedicated file, and SIEM forwarding is optional.
    """

    def __init__(self, config: Dict):
        """Reads file paths and SIEM settings from config, with defaults."""
        self.log_file = config.get("log_file", "/tmp/threat_hunting_mcp_audit.log")
        self.security_log_file = config.get(
            "security_log_file",
            "/tmp/threat_hunting_mcp_security.log")
        self.siem_enabled = config.get("siem_enabled", False)
        self.siem_endpoint = config.get("siem_endpoint")
        # Create parent directories up front so later appends cannot fail
        # on a missing directory.
        for target in (self.log_file, self.security_log_file):
            Path(target).parent.mkdir(parents=True, exist_ok=True)

    async def log(self, log_entry: Dict):
        """Persists an audit entry to the main log and, when configured, the SIEM."""
        await self._log_to_file(log_entry, self.log_file)
        if self.siem_enabled and self.siem_endpoint:
            await self._log_to_siem(log_entry)

    async def security_log(self, log_entry: Dict):
        """Writes a security-relevant entry to the dedicated security log."""
        await self._log_to_file(log_entry, self.security_log_file)

    async def _log_to_file(self, log_entry: Dict, file_path: str):
        """Appends the entry as one JSON line; failures are logged, not raised.

        NOTE(review): the write itself is synchronous inside an async
        method, so it briefly blocks the event loop.
        """
        try:
            with open(file_path, "a") as sink:
                sink.write(json.dumps(log_entry) + "\n")
        except Exception as e:
            logger.error("Failed to write audit log", error=str(e))

    async def _log_to_siem(self, log_entry: Dict):
        """Placeholder for SIEM forwarding; implementation depends on the SIEM."""
        logger.info("Would send to SIEM", entry=log_entry)
class RateLimiter:
    """Redis-based rate limiting using a sliding-window log."""

    def __init__(self, redis_config: Dict):
        """Builds a Redis connection pool; without config the limiter is a no-op."""
        self.redis_pool = None
        if redis_config:
            self.redis_pool = redis.ConnectionPool(
                host=redis_config.get("host", "localhost"),
                port=redis_config.get("port", 6379),
                db=redis_config.get("db", 0),
                password=redis_config.get("password"),
            )

    async def check_rate_limit(self, key: str, max_requests: int, window_seconds: int) -> bool:
        """Returns True when `key` has made fewer than `max_requests` calls
        within the last `window_seconds`.

        Uses a sliding-window log stored in a Redis sorted set.  Fails open:
        with no Redis pool, or on any Redis error, the request is allowed.
        """
        if not self.redis_pool:
            # If no Redis, allow all requests (not recommended for production)
            return True
        try:
            redis_conn = redis.Redis(connection_pool=self.redis_pool)
            # FIX: time.time() is a true Unix timestamp.  The previous
            # datetime.utcnow().timestamp() interprets a *naive* UTC datetime
            # in local time, skewing the window on non-UTC hosts (and
            # utcnow() is deprecated).
            now = time.time()
            window_start = now - window_seconds
            pipe = redis_conn.pipeline()
            pipe.zremrangebyscore(key, 0, window_start)  # drop expired entries
            pipe.zcard(key)                              # count requests still in window
            pipe.zadd(key, {str(now): now})              # record this request
            pipe.expire(key, window_seconds)             # garbage-collect idle keys
            results = await pipe.execute()
            current_requests = results[1]                # count *before* this request
            return current_requests < max_requests
        except Exception as e:
            logger.error("Rate limiting failed", error=str(e))
            # Fail open - allow request if rate limiting fails
            return True
class CacheManager:
    """Manages caching for threat intelligence data.

    Values are JSON-serialized into Redis with per-category TTLs; when no
    Redis pool is configured, every operation degrades to a no-op or a
    direct computation.
    """

    def __init__(self, redis_config: Dict):
        """Builds the Redis pool (if configured) and the TTL table."""
        self.redis_pool = None
        if redis_config:
            self.redis_pool = redis.ConnectionPool(
                host=redis_config.get("host", "localhost"),
                port=redis_config.get("port", 6379),
                db=redis_config.get("db", 0),
                password=redis_config.get("password"),
            )
        # TTL per cache category, in seconds.
        self.ttl_config = {
            "mitre_techniques": 86400,  # 24 hours
            "threat_actors": 14400,  # 4 hours
            "ioc_lookups": 3600,  # 1 hour
            "hunt_results": 7200,  # 2 hours
            "static_playbooks": 604800,  # 7 days
            "static_content": 86400,  # 24 hours - for MITRE matrix, methodologies, etc.
        }

    async def get_or_compute(self, key: str, compute_func, ttl_type: str):
        """Cache-aside lookup: return the cached value, or compute and store it.

        Falls back to direct computation when Redis is absent or errors out.
        """
        if not self.redis_pool:
            return await compute_func()
        try:
            conn = redis.Redis(connection_pool=self.redis_pool)
            hit = await conn.get(key)
            if hit:
                return json.loads(hit)
            value = await compute_func()
            ttl = self.ttl_config.get(ttl_type, 3600)
            await conn.setex(key, ttl, json.dumps(value))
            return value
        except Exception as e:
            logger.error("Cache operation failed", error=str(e))
            return await compute_func()

    async def invalidate_pattern(self, pattern: str):
        """Deletes every cached key matching `pattern` via incremental SCAN."""
        if not self.redis_pool:
            return
        try:
            conn = redis.Redis(connection_pool=self.redis_pool)
            cursor = 0
            while True:
                cursor, batch = await conn.scan(cursor, match=pattern, count=100)
                if batch:
                    await conn.delete(*batch)
                if cursor == 0:
                    break
        except Exception as e:
            logger.error("Cache invalidation failed", error=str(e))

    async def set(self, key: str, value: Any, ttl_type: str):
        """Stores `value` under `key` with the TTL configured for `ttl_type`."""
        if not self.redis_pool:
            return
        try:
            conn = redis.Redis(connection_pool=self.redis_pool)
            ttl = self.ttl_config.get(ttl_type, 3600)
            await conn.setex(key, ttl, json.dumps(value))
        except Exception as e:
            logger.error("Cache set failed", error=str(e))