Skip to content

Commit a35be08

Browse files
committed
Java: buildess proxy tests: add mitm_proxy.py
A mock implementation of an https man-in-the-middle proxy
1 parent acbca9c commit a35be08

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import http.server
2+
import sys
3+
import os
4+
import socket
5+
import ssl
6+
import random
7+
from datetime import datetime, timedelta, timezone
8+
from cryptography.hazmat.primitives import hashes, serialization
9+
from cryptography import utils, x509
10+
from cryptography.hazmat.primitives.asymmetric import rsa, dsa
11+
12+
import select
13+
14+
15+
def generateCA(ca_cert_file, ca_key_file):
16+
ca_key = dsa.generate_private_key(4096)
17+
name = x509.Name([
18+
x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
19+
x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "GitHub"),
20+
x509.NameAttribute(x509.NameOID.COMMON_NAME, "GitHub CodeQL Proxy")])
21+
ca_cert = x509.CertificateBuilder().subject_name(name).issuer_name(name)
22+
ca_cert = ca_cert.public_key(ca_key.public_key())
23+
ca_cert = ca_cert.serial_number(random.randint(50000000, 100000000))
24+
ca_cert = ca_cert.not_valid_before(datetime.now(timezone.utc))
25+
ca_cert = ca_cert.not_valid_after(
26+
datetime.now(timezone.utc) + timedelta(days=3650))
27+
ca_cert = ca_cert.add_extension(x509.BasicConstraints(
28+
ca=True, path_length=None), critical=True)
29+
ca_cert = ca_cert.add_extension(
30+
x509.SubjectKeyIdentifier.from_public_key(ca_key.public_key()), critical=False)
31+
ca_cert = ca_cert.sign(ca_key, hashes.SHA256())
32+
with open(ca_cert_file, 'wb') as f:
33+
f.write(ca_cert.public_bytes(encoding=serialization.Encoding.PEM))
34+
with open(ca_key_file, 'wb') as f:
35+
f.write(ca_key.private_bytes(encoding=serialization.Encoding.PEM,
36+
format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()))
37+
38+
39+
def create_certificate(hostname):
40+
pkey = rsa.generate_private_key(public_exponent=65537, key_size=2048)
41+
subject = x509.Name(
42+
[x509.NameAttribute(x509.NameOID.COMMON_NAME, hostname)])
43+
44+
cert = x509.CertificateBuilder()
45+
cert = cert.subject_name(subject).issuer_name(ca_certificate.subject)
46+
cert = cert.public_key(pkey.public_key())
47+
cert = cert.serial_number(random.randint(50000000, 100000000))
48+
cert = cert.not_valid_before(datetime.now(timezone.utc)).not_valid_after(
49+
datetime.now(timezone.utc) + timedelta(days=3650))
50+
cert = cert.add_extension(x509.BasicConstraints(
51+
ca=False, path_length=None), critical=True)
52+
cert = cert.add_extension(
53+
x509.SubjectAlternativeName([x509.DNSName(hostname), x509.DNSName(f"*.{hostname}")]), critical=False)
54+
55+
cert = cert.sign(ca_key, hashes.SHA256())
56+
57+
return (cert, pkey)
58+
59+
60+
class Handler(http.server.SimpleHTTPRequestHandler):
61+
def check_auth(self):
62+
username = os.getenv('PROXY_USER')
63+
password = os.getenv('PROXY_PASSWORD')
64+
if username is None or password is None:
65+
return True
66+
67+
authorization = self.headers.get(
68+
'Proxy-Authorization', self.headers.get('Authorization', ''))
69+
authorization = authorization.split()
70+
if len(authorization) == 2:
71+
import base64
72+
import binascii
73+
auth_type = authorization[0]
74+
if auth_type.lower() == "basic":
75+
try:
76+
authorization = authorization[1].encode('ascii')
77+
authorization = base64.decodebytes(
78+
authorization).decode('ascii')
79+
except (binascii.Error, UnicodeError):
80+
pass
81+
else:
82+
authorization = authorization.split(':')
83+
if len(authorization) == 2:
84+
return username == authorization[0] and password == authorization[1]
85+
return False
86+
87+
def do_CONNECT(self):
88+
if not self.check_auth():
89+
self.send_response(
90+
http.HTTPStatus.PROXY_AUTHENTICATION_REQUIRED)
91+
self.send_header('Proxy-Authenticate', 'Basic realm="Proxy"')
92+
self.end_headers()
93+
return
94+
# split self.path into host and port
95+
host, port = self.path.split(':')
96+
port = int(port)
97+
self.send_response(http.HTTPStatus.OK, 'Connection established')
98+
self.send_header('Connection', 'close')
99+
self.end_headers()
100+
self.mitm(host, port)
101+
102+
# man in the middle SSL connection
103+
def mitm(self, host, port):
104+
ssl_client_context = ssl.create_default_context(
105+
purpose=ssl.Purpose.CLIENT_AUTH)
106+
if not os.path.exists("certs/" + host + '.pem'):
107+
cert, pkey = create_certificate(host)
108+
with open("certs/" + host + '.pem', 'wb') as f:
109+
f.write(pkey.private_bytes(encoding=serialization.Encoding.PEM,
110+
format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption()))
111+
f.write(cert.public_bytes(encoding=serialization.Encoding.PEM))
112+
113+
ssl_client_context.load_cert_chain("certs/" + host + '.pem')
114+
ssl_client_context.load_verify_locations(ca_certificate_path)
115+
# wrap self.connection in SSL
116+
client = ssl_client_context.wrap_socket(
117+
self.connection, server_side=True)
118+
119+
# create socket to host:port
120+
remote = socket.create_connection(
121+
(host, port))
122+
# wrap socket in SSL
123+
ssl_server_context = ssl.create_default_context(
124+
purpose=ssl.Purpose.SERVER_AUTH)
125+
remote = ssl_server_context.wrap_socket(remote, server_hostname=host)
126+
127+
try:
128+
while True:
129+
ready, _, _ = select.select(
130+
[client, remote], [], [], 2.0)
131+
if not ready:
132+
break
133+
for src in ready:
134+
if src is client:
135+
dst = remote
136+
else:
137+
dst = client
138+
src.setblocking(False)
139+
dst.setblocking(True)
140+
pending = 8192
141+
while pending:
142+
try:
143+
data = src.recv(pending)
144+
except ssl.SSLWantReadError:
145+
break
146+
if not data:
147+
return
148+
pending = src.pending()
149+
dst.sendall(data)
150+
finally:
151+
remote.close()
152+
client.close()
153+
154+
def do_GET(self):
155+
raise NotImplementedError()
156+
157+
158+
if __name__ == '__main__':
159+
port = int(sys.argv[1])
160+
ca_certificate = None
161+
ca_certificate_path = None
162+
ca_key = None
163+
if len(sys.argv) > 2:
164+
ca_certificate_path = sys.argv[2]
165+
with open(ca_certificate_path, 'rb') as f:
166+
ca_certificate = x509.load_pem_x509_certificate(f.read())
167+
with open(sys.argv[3], 'rb') as f:
168+
ca_key = serialization.load_pem_private_key(
169+
f.read(), password=None)
170+
171+
server_address = ('localhost', port)
172+
httpd = http.server.HTTPServer(server_address, Handler)
173+
httpd.serve_forever()

0 commit comments

Comments
 (0)