|
22 | 22 | from contextlib import contextmanager |
23 | 23 | import signal |
24 | 24 | import sys |
| 25 | +from pythonjsonlogger import jsonlogger |
| 26 | +from marshmallow import Schema, fields, validate, ValidationError |
| 27 | + |
| 28 | +class IPSchema(Schema): |
| 29 | + ip = fields.String(required=True) |
| 30 | + description = fields.String(missing='') |
| 31 | + |
| 32 | +class DomainSchema(Schema): |
| 33 | + domain = fields.String(required=True, validate=validate.Regexp(r'^(\*\.)?(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$')) |
| 34 | + description = fields.String(missing='') |
25 | 35 |
|
26 | 36 | # Rate limiting setup |
27 | 37 | auth_attempts = defaultdict(list) |
|
53 | 63 | os.makedirs('logs', exist_ok=True) |
54 | 64 | log_path = 'logs/backend.log' |
55 | 65 |
|
56 | | -logging.basicConfig( |
57 | | - level=logging.INFO, |
58 | | - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', |
59 | | - handlers=[ |
60 | | - logging.FileHandler(log_path), |
61 | | - logging.StreamHandler() |
62 | | - ] |
63 | | -) |
64 | 66 | logger = logging.getLogger(__name__) |
| 67 | +logHandler = logging.FileHandler(log_path) |
| 68 | +formatter = jsonlogger.JsonFormatter('%(asctime)s %(name)s %(levelname)s %(message)s') |
| 69 | +logHandler.setFormatter(formatter) |
| 70 | +logger.addHandler(logHandler) |
| 71 | +logger.setLevel(logging.INFO) |
| 72 | +logger.addHandler(logging.StreamHandler()) |
65 | 73 |
|
66 | 74 | # Database configuration |
67 | 75 | DATABASE_PATH = '/data/secure_proxy.db' |
@@ -146,19 +154,31 @@ def init_db(): |
146 | 154 | ) |
147 | 155 | ''') |
148 | 156 |
|
149 | | - # Insert default admin user if not exists |
150 | | - cursor.execute("SELECT COUNT(*) FROM users WHERE username = 'admin'") |
151 | | - if cursor.fetchone()[0] == 0: |
152 | | - # Use a known password hash for 'admin' to ensure UI can authenticate |
153 | | - admin_password_hash = generate_password_hash('admin') |
154 | | - logger.info(f"Creating default admin user with password hash") |
155 | | - cursor.execute("INSERT INTO users (username, password) VALUES (?, ?)", |
156 | | - ('admin', admin_password_hash)) |
| 157 | + # Handle admin user creation/update based on environment variables |
| 158 | + env_username = os.environ.get('BASIC_AUTH_USERNAME') |
| 159 | + env_password = os.environ.get('BASIC_AUTH_PASSWORD') |
| 160 | + |
| 161 | + if env_username and env_password: |
| 162 | + # Check if user exists |
| 163 | + cursor.execute("SELECT COUNT(*) FROM users WHERE username = ?", (env_username,)) |
| 164 | + if cursor.fetchone()[0] == 0: |
| 165 | + logger.info(f"Creating admin user '{env_username}' from environment variables") |
| 166 | + cursor.execute("INSERT INTO users (username, password) VALUES (?, ?)", |
| 167 | + (env_username, generate_password_hash(env_password))) |
| 168 | + else: |
| 169 | + logger.info(f"Updating admin user '{env_username}' from environment variables") |
| 170 | + cursor.execute("UPDATE users SET password = ? WHERE username = ?", |
| 171 | + (generate_password_hash(env_password), env_username)) |
157 | 172 | else: |
158 | | - # For existing installations, ensure the admin password hash is correct |
159 | | - cursor.execute("UPDATE users SET password = ? WHERE username = ?", |
160 | | - (generate_password_hash('admin'), 'admin')) |
161 | | - logger.info("Updated admin password hash to ensure authentication works") |
| 173 | + # Check if any user exists |
| 174 | + cursor.execute("SELECT COUNT(*) FROM users") |
| 175 | + if cursor.fetchone()[0] == 0: |
| 176 | + # No users and no env vars - generate random credentials |
| 177 | + gen_username = 'admin' |
| 178 | + gen_password = secrets.token_urlsafe(16) |
| 179 | + logger.warning(f"No credentials provided. Created default user '{gen_username}' with password: {gen_password}") |
| 180 | + cursor.execute("INSERT INTO users (username, password) VALUES (?, ?)", |
| 181 | + (gen_username, generate_password_hash(gen_password))) |
162 | 182 |
|
163 | 183 | # Insert default settings if not exists |
164 | 184 | default_settings = [ |
@@ -434,13 +454,14 @@ def get_ip_blacklist(): |
434 | 454 | @auth.login_required |
435 | 455 | def add_ip_to_blacklist(): |
436 | 456 | """Add an IP to the blacklist""" |
437 | | - data = request.get_json() |
438 | | - if not data or 'ip' not in data: |
439 | | - return jsonify({"status": "error", "message": "No IP provided"}), 400 |
| 457 | + try: |
| 458 | + data = IPSchema().load(request.get_json()) |
| 459 | + except ValidationError as err: |
| 460 | + return jsonify({"status": "error", "message": err.messages}), 400 |
440 | 461 |
|
441 | 462 | # Validate IP address format |
442 | 463 | ip = data['ip'].strip() |
443 | | - description = data.get('description', '') |
| 464 | + description = data['description'] |
444 | 465 |
|
445 | 466 | # Validate CIDR notation or single IP address |
446 | 467 | try: |
@@ -492,18 +513,13 @@ def get_domain_blacklist(): |
492 | 513 | @auth.login_required |
493 | 514 | def add_domain_to_blacklist(): |
494 | 515 | """Add a domain to the blacklist""" |
495 | | - data = request.get_json() |
496 | | - if not data or 'domain' not in data: |
497 | | - return jsonify({"status": "error", "message": "No domain provided"}), 400 |
| 516 | + try: |
| 517 | + data = DomainSchema().load(request.get_json()) |
| 518 | + except ValidationError as err: |
| 519 | + return jsonify({"status": "error", "message": err.messages}), 400 |
498 | 520 |
|
499 | 521 | domain = data['domain'].strip() |
500 | | - description = data.get('description', '') |
501 | | - |
502 | | - # Basic domain validation |
503 | | - # Allow wildcard domains (*.example.com) and regular domains |
504 | | - domain_pattern = r'^(\*\.)?(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)+([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$' |
505 | | - if not re.match(domain_pattern, domain): |
506 | | - return jsonify({"status": "error", "message": "Invalid domain format"}), 400 |
| 522 | + description = data['description'] |
507 | 523 |
|
508 | 524 | conn = get_db() |
509 | 525 | cursor = conn.cursor() |
|
0 commit comments