Skip to content

Commit afe2e62

Browse files
committed
Rewrite Host header during reverse proxy
1 parent 9077c16 commit afe2e62

File tree

3 files changed

+41
-16
lines changed

3 files changed

+41
-16
lines changed

proxy/core/connection/connection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def connection(self) -> TcpOrTlsSocket:
4949

5050
def send(self, data: Union[memoryview, bytes]) -> int:
5151
"""Users must handle BrokenPipeError exceptions"""
52-
# logger.info(data.tobytes())
5352
return self.connection.send(data)
5453

5554
def recv(
@@ -67,7 +66,7 @@ def recv(
6766
return memoryview(data)
6867

6968
def close(self) -> bool:
70-
if not self.closed:
69+
if not self.closed and self.connection:
7170
self.connection.close()
7271
self.closed = True
7372
return self.closed
@@ -97,8 +96,9 @@ def flush(self, max_send_size: Optional[int] = None) -> int:
9796
self._num_buffer -= 1
9897
else:
9998
self.buffer[0] = mv[sent:]
100-
del mv
10199
logger.debug('flushed %d bytes to %s' % (sent, self.tag))
100+
# logger.info(mv[:sent].tobytes())
101+
del mv
102102
return sent
103103

104104
def is_reusable(self) -> bool:

proxy/http/parser/parser.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,12 @@ def parse(
283283
self.state = httpParserStates.COMPLETE
284284
self.buffer = None if raw == b'' else raw
285285

286-
def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = False) -> bytes:
286+
def build(
287+
self,
288+
disable_headers: Optional[List[bytes]] = None,
289+
for_proxy: bool = False,
290+
host: Optional[bytes] = None,
291+
) -> bytes:
287292
"""Rebuild the request object."""
288293
assert self.method and self.version and self.type == httpParserTypes.REQUEST_PARSER
289294
if disable_headers is None:
@@ -301,11 +306,22 @@ def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool =
301306
path
302307
) if not self._is_https_tunnel else (self.host + COLON + str(self.port).encode())
303308
return build_http_request(
304-
self.method, path, self.version,
305-
headers={} if not self.headers else {
306-
self.headers[k][0]: self.headers[k][1] for k in self.headers if
307-
k.lower() not in disable_headers
308-
},
309+
self.method,
310+
path,
311+
self.version,
312+
headers=(
313+
{}
314+
if not self.headers
315+
else {
316+
self.headers[k][0]: (
317+
self.headers[k][1]
318+
if host is None or self.headers[k][0].lower() != b'host'
319+
else host
320+
)
321+
for k in self.headers
322+
if k.lower() not in disable_headers
323+
}
324+
),
309325
body=body,
310326
no_ua=True,
311327
)

proxy/http/server/reverse.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from proxy.common.utils import text_
2121
from proxy.http.exception import HttpProtocolException
2222
from proxy.common.constants import (
23-
HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT,
23+
COLON, HTTP_PROTO, HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT,
2424
DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT,
2525
)
2626
from ...common.types import Readables, Writables, Descriptors
@@ -111,23 +111,32 @@ def handle_request(self, request: HttpParser) -> None:
111111
assert self.choice and self.choice.hostname
112112
port = (
113113
self.choice.port or DEFAULT_HTTP_PORT
114-
if self.choice.scheme == b'http'
115-
else DEFAULT_HTTPS_PORT
114+
if self.choice.scheme == HTTP_PROTO
115+
else self.choice.port or DEFAULT_HTTPS_PORT
116116
)
117117
self.initialize_upstream(text_(self.choice.hostname), port)
118118
assert self.upstream
119119
try:
120120
self.upstream.connect()
121121
if self.choice.scheme == HTTPS_PROTO:
122122
self.upstream.wrap(
123-
text_(
124-
self.choice.hostname,
125-
),
123+
text_(self.choice.hostname),
126124
as_non_blocking=True,
127125
ca_file=self.flags.ca_file,
128126
)
129127
request.path = self.choice.remainder
130-
self.upstream.queue(memoryview(request.build()))
128+
self.upstream.queue(
129+
memoryview(
130+
request.build(
131+
host=self.choice.hostname
132+
+ (
133+
COLON + self.choice.port.to_bytes()
134+
if self.choice.port is not None
135+
else b''
136+
),
137+
),
138+
),
139+
)
131140
except ConnectionRefusedError:
132141
raise HttpProtocolException( # pragma: no cover
133142
'Connection refused by upstream server {0}:{1}'.format(

0 commit comments

Comments
 (0)