Skip to content

Commit c675d88

Browse files
committed
feat: Enhance security and input validation with environment-based credentials, CSRF protection, and structured logging
1 parent e37ca7f commit c675d88

File tree

12 files changed

+207
-136
lines changed

12 files changed

+207
-136
lines changed

backend/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ WORKDIR /app
44

55
# Install dependencies with specific order to handle dependencies correctly
66
COPY requirements.txt .
7-
RUN pip install --no-cache-dir werkzeug==2.2.3 && \
8-
pip install --no-cache-dir -r requirements.txt
7+
RUN pip install --no-cache-dir -r requirements.txt
98

109
# Install curl for healthcheck (Docker CLI removed for security)
1110
RUN apt-get update && \

backend/app/app.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@
2222
from contextlib import contextmanager
2323
import signal
2424
import sys
25+
from pythonjsonlogger import jsonlogger
26+
from marshmallow import Schema, fields, validate, ValidationError
27+
28+
class IPSchema(Schema):
29+
ip = fields.String(required=True)
30+
description = fields.String(missing='')
31+
32+
class DomainSchema(Schema):
33+
domain = fields.String(required=True, validate=validate.Regexp(r'^(\*\.)?(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$'))
34+
description = fields.String(missing='')
2535

2636
# Rate limiting setup
2737
auth_attempts = defaultdict(list)
@@ -53,15 +63,13 @@
5363
os.makedirs('logs', exist_ok=True)
5464
log_path = 'logs/backend.log'
5565

56-
logging.basicConfig(
57-
level=logging.INFO,
58-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
59-
handlers=[
60-
logging.FileHandler(log_path),
61-
logging.StreamHandler()
62-
]
63-
)
6466
logger = logging.getLogger(__name__)
67+
logHandler = logging.FileHandler(log_path)
68+
formatter = jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s')
69+
logHandler.setFormatter(formatter)
70+
logger.addHandler(logHandler)
71+
logger.setLevel(logging.INFO)
72+
logger.addHandler(logging.StreamHandler())
6573

6674
# Database configuration
6775
DATABASE_PATH = '/data/secure_proxy.db'
@@ -146,19 +154,31 @@ def init_db():
146154
)
147155
''')
148156

149-
# Insert default admin user if not exists
150-
cursor.execute("SELECT COUNT(*) FROM users WHERE username = 'admin'")
151-
if cursor.fetchone()[0] == 0:
152-
# Use a known password hash for 'admin' to ensure UI can authenticate
153-
admin_password_hash = generate_password_hash('admin')
154-
logger.info(f"Creating default admin user with password hash")
155-
cursor.execute("INSERT INTO users (username, password) VALUES (?, ?)",
156-
('admin', admin_password_hash))
157+
# Handle admin user creation/update based on environment variables
158+
env_username = os.environ.get('BASIC_AUTH_USERNAME')
159+
env_password = os.environ.get('BASIC_AUTH_PASSWORD')
160+
161+
if env_username and env_password:
162+
# Check if user exists
163+
cursor.execute("SELECT COUNT(*) FROM users WHERE username = ?", (env_username,))
164+
if cursor.fetchone()[0] == 0:
165+
logger.info(f"Creating admin user '{env_username}' from environment variables")
166+
cursor.execute("INSERT INTO users (username, password) VALUES (?, ?)",
167+
(env_username, generate_password_hash(env_password)))
168+
else:
169+
logger.info(f"Updating admin user '{env_username}' from environment variables")
170+
cursor.execute("UPDATE users SET password = ? WHERE username = ?",
171+
(generate_password_hash(env_password), env_username))
157172
else:
158-
# For existing installations, ensure the admin password hash is correct
159-
cursor.execute("UPDATE users SET password = ? WHERE username = ?",
160-
(generate_password_hash('admin'), 'admin'))
161-
logger.info("Updated admin password hash to ensure authentication works")
173+
# Check if any user exists
174+
cursor.execute("SELECT COUNT(*) FROM users")
175+
if cursor.fetchone()[0] == 0:
176+
# No users and no env vars - generate random credentials
177+
gen_username = 'admin'
178+
gen_password = secrets.token_urlsafe(16)
179+
logger.warning(f"No credentials provided. Created default user '{gen_username}' with password: {gen_password}")
180+
cursor.execute("INSERT INTO users (username, password) VALUES (?, ?)",
181+
(gen_username, generate_password_hash(gen_password)))
162182

163183
# Insert default settings if not exists
164184
default_settings = [
@@ -434,13 +454,14 @@ def get_ip_blacklist():
434454
@auth.login_required
435455
def add_ip_to_blacklist():
436456
"""Add an IP to the blacklist"""
437-
data = request.get_json()
438-
if not data or 'ip' not in data:
439-
return jsonify({"status": "error", "message": "No IP provided"}), 400
457+
try:
458+
data = IPSchema().load(request.get_json())
459+
except ValidationError as err:
460+
return jsonify({"status": "error", "message": err.messages}), 400
440461

441462
# Validate IP address format
442463
ip = data['ip'].strip()
443-
description = data.get('description', '')
464+
description = data['description']
444465

445466
# Validate CIDR notation or single IP address
446467
try:
@@ -492,18 +513,13 @@ def get_domain_blacklist():
492513
@auth.login_required
493514
def add_domain_to_blacklist():
494515
"""Add a domain to the blacklist"""
495-
data = request.get_json()
496-
if not data or 'domain' not in data:
497-
return jsonify({"status": "error", "message": "No domain provided"}), 400
516+
try:
517+
data = DomainSchema().load(request.get_json())
518+
except ValidationError as err:
519+
return jsonify({"status": "error", "message": err.messages}), 400
498520

499521
domain = data['domain'].strip()
500-
description = data.get('description', '')
501-
502-
# Basic domain validation
503-
# Allow wildcard domains (*.example.com) and regular domains
504-
domain_pattern = r'^(\*\.)?(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$'
505-
if not re.match(domain_pattern, domain):
506-
return jsonify({"status": "error", "message": "Invalid domain format"}), 400
522+
description = data['description']
507523

508524
conn = get_db()
509525
cursor = conn.cursor()

backend/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ sqlalchemy==2.0.36
99
python-dotenv==1.0.1
1010
pytz==2024.2
1111
werkzeug==3.1.3
12-
markupsafe==2.1.5
12+
markupsafe==2.1.5
13+
python-json-logger==2.0.7
14+
marshmallow==3.20.1

docker-compose.yml

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ services:
44
ports:
55
- "8011:8011"
66
volumes:
7-
- ./ui:/app
87
- ./config:/config
98
- ./data:/data
109
- ./logs:/logs
@@ -14,33 +13,45 @@ services:
1413
environment:
1514
- FLASK_ENV=production
1615
- BACKEND_URL=http://backend:5000
17-
- BASIC_AUTH_USERNAME=admin
18-
- BASIC_AUTH_PASSWORD=admin
1916
- REQUEST_TIMEOUT=30
2017
- MAX_RETRIES=5
2118
- BACKOFF_FACTOR=1.0
2219
- RETRY_WAIT_AFTER_STARTUP=10
20+
# Credentials must be provided via .env file or environment variables
21+
- BASIC_AUTH_USERNAME
22+
- BASIC_AUTH_PASSWORD
23+
- SECRET_KEY
24+
deploy:
25+
resources:
26+
limits:
27+
cpus: '0.50'
28+
memory: 512M
2329
networks:
2430
- proxy-network
2531
restart: unless-stopped
2632

2733
backend:
2834
build: ./backend
2935
ports:
30-
- "5001:5000" # Map container port 5000 to host port 5001
36+
- "5001:5000"
3137
volumes:
32-
- ./backend:/app
3338
- ./config:/config
3439
- ./data:/data
3540
- ./logs:/logs
36-
# Docker socket mount removed for security - see SECURITY.md
3741
environment:
3842
- FLASK_ENV=production
3943
- PROXY_HOST=proxy
4044
- PROXY_PORT=3128
41-
- BASIC_AUTH_USERNAME=admin
42-
- BASIC_AUTH_PASSWORD=admin
43-
- PROXY_CONTAINER_NAME=secure-proxy-proxy-1 # Add container name for restart commands
45+
- PROXY_CONTAINER_NAME=secure-proxy-proxy-1
46+
# Credentials must be provided via .env file or environment variables
47+
- BASIC_AUTH_USERNAME
48+
- BASIC_AUTH_PASSWORD
49+
- SECRET_KEY
50+
deploy:
51+
resources:
52+
limits:
53+
cpus: '0.50'
54+
memory: 512M
4455
healthcheck:
4556
test: ["CMD", "curl", "-f", "http://localhost:5000/health"]
4657
interval: 5s
@@ -56,11 +67,15 @@ services:
5667
ports:
5768
- "3128:3128"
5869
volumes:
59-
# Removed the direct mount of squid.conf
6070
- ./config:/config
6171
- ./data:/data
6272
- ./logs:/var/log/squid
6373
- squid-cache:/var/spool/squid
74+
deploy:
75+
resources:
76+
limits:
77+
cpus: '0.50'
78+
memory: 512M
6479
networks:
6580
- proxy-network
6681
restart: unless-stopped
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import unittest
2+
import json
3+
import os
4+
import sys
5+
from unittest.mock import patch, MagicMock
6+
7+
# Add backend to path
8+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../backend')))
9+
10+
# Set env vars before importing app
11+
os.environ['BASIC_AUTH_USERNAME'] = 'admin'
12+
os.environ['BASIC_AUTH_PASSWORD'] = 'admin'
13+
os.environ['SECRET_KEY'] = 'test_secret'
14+
15+
from app.app import app, init_db, get_db
16+
17+
class SecurityTests(unittest.TestCase):
18+
def setUp(self):
19+
app.config['TESTING'] = True
20+
self.client = app.test_client()
21+
22+
# Setup in-memory db
23+
with app.app_context():
24+
# Mock database path to use in-memory
25+
with patch('app.app.DATABASE_PATH', ':memory:'):
26+
init_db()
27+
28+
def test_input_validation_ip(self):
29+
# Test invalid IP
30+
response = self.client.post('/api/ip-blacklist',
31+
headers={'Authorization': 'Basic YWRtaW46YWRtaW4='}, # admin:admin
32+
json={'ip': 'invalid-ip'}
33+
)
34+
self.assertEqual(response.status_code, 400)
35+
36+
# Test valid IP
37+
response = self.client.post('/api/ip-blacklist',
38+
headers={'Authorization': 'Basic YWRtaW46YWRtaW4='},
39+
json={'ip': '1.2.3.4', 'description': 'test'}
40+
)
41+
self.assertEqual(response.status_code, 200)
42+
43+
def test_input_validation_domain(self):
44+
# Test invalid domain
45+
response = self.client.post('/api/domain-blacklist',
46+
headers={'Authorization': 'Basic YWRtaW46YWRtaW4='},
47+
json={'domain': '-invalid-domain'}
48+
)
49+
self.assertEqual(response.status_code, 400)
50+
51+
# Test valid domain
52+
response = self.client.post('/api/domain-blacklist',
53+
headers={'Authorization': 'Basic YWRtaW46YWRtaW4='},
54+
json={'domain': 'example.com', 'description': 'test'}
55+
)
56+
self.assertEqual(response.status_code, 200)
57+
58+
def test_settings_update(self):
59+
# Test update setting
60+
response = self.client.put('/api/settings/log_level',
61+
headers={'Authorization': 'Basic YWRtaW46YWRtaW4='},
62+
json={'value': 'debug'}
63+
)
64+
self.assertEqual(response.status_code, 200)
65+
66+
# Test invalid setting value
67+
response = self.client.put('/api/settings/log_level',
68+
headers={'Authorization': 'Basic YWRtaW46YWRtaW4='},
69+
json={'value': 'invalid'}
70+
)
71+
# validate_setting returns False, so it should be 400
72+
# In app.py: if not validate_setting(...): return ..., 400
73+
self.assertEqual(response.status_code, 400)
74+
75+
if __name__ == '__main__':
76+
unittest.main()

0 commit comments

Comments
 (0)