2323
2424from twisted .internet import address
2525from twisted .web .resource import IResource
26+ from twisted .web .server import Request
2627
2728import synapse
2829import synapse .events
@@ -190,7 +191,7 @@ def __init__(self, hs):
190191 self .http_client = hs .get_simple_http_client ()
191192 self .main_uri = hs .config .worker_main_http_uri
192193
193- async def on_POST (self , request , device_id ):
194+ async def on_POST (self , request : Request , device_id : Optional [ str ] ):
194195 requester = await self .auth .get_user_by_req (request , allow_guest = True )
195196 user_id = requester .user .to_string ()
196197 body = parse_json_object_from_request (request )
@@ -223,10 +224,12 @@ async def on_POST(self, request, device_id):
223224 header : request .requestHeaders .getRawHeaders (header , [])
224225 for header in (b"Authorization" , b"User-Agent" )
225226 }
226- # Add the previous hop the the X-Forwarded-For header.
227+ # Add the previous hop to the X-Forwarded-For header.
227228 x_forwarded_for = request .requestHeaders .getRawHeaders (
228229 b"X-Forwarded-For" , []
229230 )
231+ # we use request.client here, since we want the previous hop, not the
232+ # original client (as returned by request.getClientAddress()).
230233 if isinstance (request .client , (address .IPv4Address , address .IPv6Address )):
231234 previous_host = request .client .host .encode ("ascii" )
232235 # If the header exists, add to the comma-separated list of the first
@@ -239,6 +242,14 @@ async def on_POST(self, request, device_id):
239242 x_forwarded_for = [previous_host ]
240243 headers [b"X-Forwarded-For" ] = x_forwarded_for
241244
245+ # Replicate the original X-Forwarded-Proto header. Note that
246+ # XForwardedForRequest overrides isSecure() to give us the original protocol
247+ # used by the client, as opposed to the protocol used by our upstream proxy
248+ # - which is what we want here.
249+ headers [b"X-Forwarded-Proto" ] = [
250+ b"https" if request .isSecure () else b"http"
251+ ]
252+
242253 try :
243254 result = await self .http_client .post_json_get_json (
244255 self .main_uri + request .uri .decode ("ascii" ), body , headers = headers
0 commit comments