1212from http .server import ThreadingHTTPServer , BaseHTTPRequestHandler
1313import logging
1414import struct
15+ import httpx
16+ from dataclasses import dataclass
17+ from typing import Dict
18+ import ssl
1519
1620
1721class GrpcWebProxy (object ):
@@ -26,15 +30,20 @@ def __init__(self, scheduler: Scheduler, grpc_port: int):
2630 self .logger .info (
2731 f"GrpcWebProxy configured to forward requests from web_port={ self .web_port } to grpc_port={ self .grpc_port } "
2832 )
33+ self .handler_cls = Handler
2934
3035 def start (self ):
3136 self ._thread = Thread (target = self .run , daemon = True )
3237 self .logger .info (f"Starting grpc-web-proxy on port { self .web_port } " )
3338 self .running = True
3439 server_address = ("127.0.0.1" , self .web_port )
3540
36- self .httpd = ThreadingHTTPServer (server_address , Handler )
41+ self .httpd = ThreadingHTTPServer (server_address , self . handler_cls )
3742 self .httpd .grpc_port = self .grpc_port
43+
44+ # Just a simple way to pass the scheduler to the handler
45+ self .httpd .scheduler = self .scheduler
46+
3847 self .logger .debug (f"Server startup complete" )
3948 self ._thread .start ()
4049
@@ -47,11 +56,49 @@ def stop(self):
4756 self ._thread .join ()
4857
4958
59+ @dataclass
60+ class Request :
61+ body : bytes
62+ headers : Dict [str , str ]
63+ flags : int
64+ length : int
65+
66+
67+ @dataclass
68+ class Response :
69+ body : bytes
70+
71+
5072class Handler (BaseHTTPRequestHandler ):
5173 def __init__ (self , * args , ** kwargs ):
5274 self .logger = logging .getLogger ("gltesting.grpcweb.Handler" )
5375 BaseHTTPRequestHandler .__init__ (self , * args , ** kwargs )
5476
77+ def proxy (self , request ) -> Response :
78+ """Callback called with the request, implementing the proxying."""
79+ url = f"http://localhost:{ self .server .grpc_port } { self .path } "
80+ self .logger .debug (f"Forwarding request to '{ url } '" )
81+ headers = {
82+ "te" : "trailers" ,
83+ "Content-Type" : "application/grpc" ,
84+ "grpc-accept-encoding" : "identity" ,
85+ "user-agent" : "gl-testing-grpc-web-proxy" ,
86+ }
87+ content = struct .pack ("!cI" , request .flags , request .length ) + request .body
88+ req = httpx .Request (
89+ "POST" ,
90+ url ,
91+ headers = headers ,
92+ content = content ,
93+ )
94+ client = httpx .Client (http1 = False , http2 = True )
95+ res = client .send (req )
96+ return Response (body = res .content )
97+
98+ def auth (self , request : Request ) -> bool :
99+ """Authenticate the request. True means allow."""
100+ return True
101+
55102 def do_POST (self ):
56103 # We don't actually touch the payload, so we do not really
57104 # care about the flags ourselves. The upstream sysmte will
@@ -69,40 +116,95 @@ def do_POST(self):
69116 # need to decode it, and we can treat it as opaque blob.
70117 body = self .rfile .read (length )
71118
119+ req = Request (body = body , headers = self .headers , flags = flags , length = length )
120+ if not self .auth (req ):
121+ self .wfile .write (b"HTTP/1.1 401 Unauthorized\r \n \r \n " )
122+ return
123+
124+ response = self .proxy (req )
125+ self .wfile .write (b"HTTP/1.0 200 OK\n \n " )
126+ self .wfile .write (response .body )
127+ self .wfile .flush ()
128+
129+
130+ class NodeHandler (Handler ):
131+ """A handler that is aware of nodes, their auth and how they schedule."""
132+
133+ def __init__ (self , * args , ** kwargs ):
134+ self .logger = logging .getLogger ("gltesting.grpcweb.NodeHandler" )
135+ BaseHTTPRequestHandler .__init__ (self , * args , ** kwargs )
136+
137+ def auth (self , request : Request ) -> bool :
72138 # TODO extract the `glauthpubkey` and the `glauthsig`, then
73139 # verify them. Fail the call if the verification fails,
74140 # forward otherwise.
75141 # This is just a test server, and we don't make use of the
76142 # multiplexing support in `h2`, which simplifies this proxy
77143 # quite a bit. The production server maintains a cache of
78144 # connections and multiplexes correctly.
145+ pk = request .headers .get ("glauthpubkey" , None )
146+ sig = request .headers .get ("glauthsig" , None )
147+ ts = request .headers .get ("glts" , None )
79148
80- import httpx
149+ if not pk :
150+ self .logger .warn (f"Missing public key header" )
151+ return False
81152
82- url = f"http://localhost:{ self .server .grpc_port } { self .path } "
83- self .logger .debug (f"Forwarding request to '{ url } '" )
153+ if not sig :
154+ self .logger .warn (f"Missing signature header" )
155+ return False
156+
157+ if not ts :
158+ self .logger .warn (f"Missing timestamp header" )
159+ return False
160+
161+ # TODO Check the signature.
162+ return True
163+
164+ def proxy (self , request : Request ):
165+ # Fetch current location of the node
166+
167+ pk = request .headers .get ("glauthpubkey" )
168+ from base64 import b64decode
169+
170+ pk = b64decode (pk )
171+
172+ node = self .server .scheduler .get_node (pk )
173+ self .logger .debug (f"Found node for node_id={ pk .hex ()} " )
174+
175+ # TODO Schedule node if not scheduled
176+
177+ client_cert = node .identity .private_key
178+ ca_path = node .identity .caroot_path
179+
180+ # Load TLS client cert info client
181+ ctx = httpx .create_ssl_context (
182+ verify = ca_path ,
183+ http2 = True ,
184+ cert = (
185+ node .identity .cert_chain_path ,
186+ node .identity .private_key_path ,
187+ ),
188+ )
189+ client = httpx .Client (http1 = False , http2 = True , verify = ctx )
190+
191+ url = f"{ node .process .grpc_uri } { self .path } "
84192 headers = {
85193 "te" : "trailers" ,
86194 "Content-Type" : "application/grpc" ,
87- "grpc-accept-encoding" : "idenity " ,
88- "user-agent" : "My bloody hacked up script " ,
195+ "grpc-accept-encoding" : "identity " ,
196+ "user-agent" : "gl-testing-grpc-web-proxy " ,
89197 }
90- content = struct .pack ("!cI" , flags , length ) + body
198+ content = struct .pack ("!cI" , request .flags , request .length ) + request .body
199+
200+ # Forward request
91201 req = httpx .Request (
92202 "POST" ,
93203 url ,
94204 headers = headers ,
95205 content = content ,
96206 )
97- client = httpx .Client (http1 = False , http2 = True )
98-
99- res = client .send (req )
100207 res = client .send (req )
101208
102- canned = b"\n \r heklllo world"
103- l = struct .pack ("!I" , len (canned ))
104- self .wfile .write (b"HTTP/1.0 200 OK\n \n " )
105- self .wfile .write (b"\x00 " )
106- self .wfile .write (l )
107- self .wfile .write (canned )
108- self .wfile .flush ()
209+ # Return response
210+ return Response (body = res .content )
0 commit comments