11import argparse
22import asyncio
3+ import os
4+ import re
35import signal
6+ import socket
7+ import ssl
8+ import sys
49import time
510from collections .abc import AsyncIterator , Iterator
6- from contextlib import ExitStack
7- from ssl import VerifyMode
11+ from contextlib import ExitStack , closing
812from tempfile import NamedTemporaryFile
913from typing import TYPE_CHECKING , Literal , TypeVar
1014
3640 UnaryResponseDefinition ,
3741)
3842from google .protobuf .any_pb2 import Any
39- from hypercorn .asyncio import serve as hypercorn_serve
4043from hypercorn .config import Config as HypercornConfig
4144from hypercorn .logging import Logger
4245
@@ -402,76 +405,228 @@ async def info(self, message: str, *args: Any, **kwargs: Any) -> None:
402405 await super ().info (message , * args , ** kwargs )
403406
404407
405- async def serve (
406- request : ServerCompatRequest , mode : Literal ["sync" , "async" ]
407- ) -> tuple [asyncio .Task , int ]:
408- read_max_bytes = request .message_receive_limit or None
409- match mode :
410- case "async" :
411- app = ConformanceServiceASGIApplication (
412- TestService (), read_max_bytes = read_max_bytes
413- )
414- case "sync" :
415- app = ConformanceServiceWSGIApplication (
416- TestServiceSync (), read_max_bytes = read_max_bytes
417- )
418-
419- conf = HypercornConfig ()
420- conf .bind = ["127.0.0.1:0" ]
408+ read_max_bytes = os .getenv ("READ_MAX_BYTES" )
409+ if read_max_bytes is not None :
410+ read_max_bytes = int (read_max_bytes )
421411
422- cleanup = ExitStack ()
423- if request .use_tls :
424- cert_file = cleanup .enter_context (NamedTemporaryFile ())
425- key_file = cleanup .enter_context (NamedTemporaryFile ())
426- cert_file .write (request .server_creds .cert )
427- cert_file .flush ()
428- key_file .write (request .server_creds .key )
429- key_file .flush ()
430- conf .certfile = cert_file .name
431- conf .keyfile = key_file .name
432- if request .client_tls_cert :
433- ca_cert_file = cleanup .enter_context (NamedTemporaryFile ())
434- ca_cert_file .write (request .client_tls_cert )
435- ca_cert_file .flush ()
436- conf .ca_certs = ca_cert_file .name
437- conf .verify_mode = VerifyMode .CERT_REQUIRED
438-
439- conf ._log = PortCapturingLogger (conf )
412+ asgi_app = ConformanceServiceASGIApplication (
413+ TestService (), read_max_bytes = read_max_bytes
414+ )
415+ wsgi_app = ConformanceServiceWSGIApplication (
416+ TestServiceSync (), read_max_bytes = read_max_bytes
417+ )
440418
441- shutdown_event = asyncio .Event ()
442419
443- def _signal_handler (* _ ) -> None :
444- cleanup .close ()
445- shutdown_event .set ()
420+ def _server_env (request : ServerCompatRequest ) -> dict [str , str ]:
421+ pythonpath = os .pathsep .join (sys .path )
422+ env = {
423+ ** os .environ ,
424+ "PYTHONPATH" : pythonpath ,
425+ "PYTHONHOME" : f"{ sys .prefix } :{ sys .exec_prefix } " ,
426+ }
427+ if request .message_receive_limit :
428+ env ["READ_MAX_BYTES" ] = str (request .message_receive_limit )
429+ return env
430+
431+
432+ _port_regex = re .compile (r".*://[^:]+:(\d+).*" )
433+
434+
435+ async def serve_granian (
436+ request : ServerCompatRequest ,
437+ mode : Literal ["sync" , "async" ],
438+ certfile : str | None ,
439+ keyfile : str | None ,
440+ cafile : str | None ,
441+ port_future : asyncio .Future [int ],
442+ ):
443+ # Granian seems to have a bug that it prints out 0 rather than the resolved port,
444+ # so we need to determine it ourselves. If we see race conditions because of it,
445+ # we can set max-servers=1 in the runner.
446+ port = _find_free_port ()
447+ args = [f"--port={ port } " ]
448+ if certfile :
449+ args .append (f"--ssl-certificate={ certfile } " )
450+ if keyfile :
451+ args .append (f"--ssl-keyfile={ keyfile } " )
452+ if cafile :
453+ args .append (f"--ssl-ca={ cafile } " )
454+ args .append ("--ssl-client-verify" )
455+
456+ if mode == "sync" :
457+ args .append ("--interface=wsgi" )
458+ args .append ("server:wsgi_app" )
459+ else :
460+ args .append ("--interface=asgi" )
461+ args .append ("server:asgi_app" )
462+
463+ proc = await asyncio .create_subprocess_exec (
464+ "granian" ,
465+ * args ,
466+ stderr = asyncio .subprocess .STDOUT ,
467+ stdout = asyncio .subprocess .PIPE ,
468+ limit = 1024 ,
469+ env = _server_env (request ),
470+ )
471+ stdout = proc .stdout
472+ assert stdout is not None # noqa: S101
473+ try :
474+ for _ in range (100 ):
475+ line = await stdout .readline ()
476+ if b"Listening at:" in line :
477+ break
478+ port_future .set_result (port )
479+ await proc .wait ()
480+ except asyncio .CancelledError :
481+ proc .terminate ()
482+ await proc .wait ()
483+
484+
485+ async def serve_gunicorn (
486+ request : ServerCompatRequest ,
487+ certfile : str | None ,
488+ keyfile : str | None ,
489+ cafile : str | None ,
490+ port_future : asyncio .Future [int ],
491+ ):
492+ args = ["--bind=127.0.0.1:0" , "--workers=8" ]
493+ if certfile :
494+ args .append (f"--certfile={ certfile } " )
495+ if keyfile :
496+ args .append (f"--keyfile={ keyfile } " )
497+ if cafile :
498+ args .append (f"--ca-certs={ cafile } " )
499+ args .append (f"--cert-reqs={ ssl .CERT_REQUIRED } " )
500+
501+ args .append ("server:wsgi_app" )
502+
503+ proc = await asyncio .create_subprocess_exec (
504+ "gunicorn" ,
505+ * args ,
506+ stderr = asyncio .subprocess .STDOUT ,
507+ stdout = asyncio .subprocess .PIPE ,
508+ limit = 1024 ,
509+ env = _server_env (request ),
510+ )
511+ stdout = proc .stdout
512+ assert stdout is not None # noqa: S101
513+ try :
514+ for _ in range (100 ):
515+ line = await stdout .readline ()
516+ match = _port_regex .match (line .decode ("utf-8" ))
517+ if match :
518+ port_future .set_result (int (match .group (1 )))
519+ break
520+ await proc .wait ()
521+ except asyncio .CancelledError :
522+ proc .terminate ()
523+ await proc .wait ()
524+
525+
526+ async def serve_hypercorn (
527+ request : ServerCompatRequest ,
528+ mode : Literal ["sync" , "async" ],
529+ certfile : str | None ,
530+ keyfile : str | None ,
531+ cafile : str | None ,
532+ port_future : asyncio .Future [int ],
533+ ):
534+ args = ["--bind=localhost:0" ]
535+ if certfile :
536+ args .append (f"--certfile={ certfile } " )
537+ if keyfile :
538+ args .append (f"--keyfile={ keyfile } " )
539+ if cafile :
540+ args .append (f"--ca-certs={ cafile } " )
541+ args .append ("--verify-mode=CERT_REQUIRED" )
542+
543+ if mode == "sync" :
544+ args .append ("server:wsgi_app" )
545+ else :
546+ args .append ("server:asgi_app" )
547+
548+ proc = await asyncio .create_subprocess_exec (
549+ "hypercorn" ,
550+ * args ,
551+ stderr = asyncio .subprocess .STDOUT ,
552+ stdout = asyncio .subprocess .PIPE ,
553+ limit = 1024 ,
554+ env = _server_env (request ),
555+ )
556+ stdout = proc .stdout
557+ assert stdout is not None # noqa: S101
558+ try :
559+ for _ in range (100 ):
560+ line = await stdout .readline ()
561+ match = _port_regex .match (line .decode ("utf-8" ))
562+ if match :
563+ port_future .set_result (int (match .group (1 )))
564+ break
565+ await proc .wait ()
566+ except asyncio .CancelledError :
567+ proc .terminate ()
568+ await proc .wait ()
569+
570+
571+ async def serve_uvicorn (
572+ request : ServerCompatRequest ,
573+ certfile : str | None ,
574+ keyfile : str | None ,
575+ cafile : str | None ,
576+ port_future : asyncio .Future [int ],
577+ ):
578+ args = ["--port=0" , "--no-access-log" ]
579+ if certfile :
580+ args .append (f"--ssl-certfile={ certfile } " )
581+ if keyfile :
582+ args .append (f"--ssl-keyfile={ keyfile } " )
583+ if cafile :
584+ args .append (f"--ssl-ca-certs={ cafile } " )
585+ args .append (f"--ssl-cert-reqs={ ssl .CERT_REQUIRED } " )
586+
587+ args .append ("server:asgi_app" )
588+
589+ proc = await asyncio .create_subprocess_exec (
590+ "uvicorn" ,
591+ * args ,
592+ stderr = asyncio .subprocess .STDOUT ,
593+ stdout = asyncio .subprocess .PIPE ,
594+ limit = 1024 ,
595+ env = _server_env (request ),
596+ )
597+ stdout = proc .stdout
598+ assert stdout is not None # noqa: S101
599+ try :
600+ for _ in range (100 ):
601+ line = await stdout .readline ()
602+ match = _port_regex .match (line .decode ("utf-8" ))
603+ if match :
604+ port_future .set_result (int (match .group (1 )))
605+ break
606+ await proc .wait ()
607+ except asyncio .CancelledError :
608+ proc .terminate ()
609+ await proc .wait ()
446610
447- loop = asyncio .get_event_loop ()
448- loop .add_signal_handler (signal .SIGTERM , _signal_handler )
449- loop .add_signal_handler (signal .SIGINT , _signal_handler )
450611
451- serve_task = loop .create_task (
452- hypercorn_serve (
453- app , # pyright:ignore[reportArgumentType] - some incompatibility in type
454- conf ,
455- shutdown_trigger = shutdown_event .wait ,
456- mode = "asgi" if mode == "async" else "wsgi" ,
457- )
458- )
459- port = - 1
460- for _ in range (100 ):
461- port = conf ._log .port
462- if port != - 1 :
463- break
464- await asyncio .sleep (0.01 )
465- return serve_task , port
612+ def _find_free_port ():
613+ with closing (socket .socket (socket .AF_INET , socket .SOCK_STREAM )) as s :
614+ s .bind (("" , 0 ))
615+ s .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
616+ return s .getsockname ()[1 ]
466617
467618
468619class Args (argparse .Namespace ):
469620 mode : Literal ["sync" , "async" ]
621+ server : Literal ["granian" , "hypercorn" , "uvicorn" ]
470622
471623
472624async def main () -> None :
473- parser = argparse .ArgumentParser (description = "Conformance client " )
625+ parser = argparse .ArgumentParser (description = "Conformance server " )
474626 parser .add_argument ("--mode" , choices = ["sync" , "async" ])
627+ parser .add_argument (
628+ "--server" , choices = ["granian" , "gunicorn" , "hypercorn" , "uvicorn" ]
629+ )
475630 args = parser .parse_args (namespace = Args ())
476631
477632 stdin , stdout = await create_standard_streams ()
@@ -485,19 +640,67 @@ async def main() -> None:
485640 request = ServerCompatRequest ()
486641 request .ParseFromString (request_buf )
487642
488- serve_task , port = await serve ( request , args . mode )
489- response = ServerCompatResponse ()
490- response . host = "127.0.0.1"
491- response . port = port
643+ cleanup = ExitStack ( )
644+ certfile = None
645+ keyfile = None
646+ cafile = None
492647 if request .use_tls :
493- response .pem_cert = request .server_creds .cert
494- response_buf = response .SerializeToString ()
495- size_buf = len (response_buf ).to_bytes (4 , byteorder = "big" )
496- stdout .write (size_buf )
497- stdout .write (response_buf )
498- await stdout .drain ()
499- # Runner will send sigterm which is handled by serve
500- await serve_task
648+ cert_file = cleanup .enter_context (NamedTemporaryFile ())
649+ key_file = cleanup .enter_context (NamedTemporaryFile ())
650+ cert_file .write (request .server_creds .cert )
651+ cert_file .flush ()
652+ key_file .write (request .server_creds .key )
653+ key_file .flush ()
654+ certfile = cert_file .name
655+ keyfile = key_file .name
656+ if request .client_tls_cert :
657+ ca_cert_file = cleanup .enter_context (NamedTemporaryFile ())
658+ ca_cert_file .write (request .client_tls_cert )
659+ ca_cert_file .flush ()
660+ cafile = ca_cert_file .name
661+
662+ with cleanup :
663+ port_future : asyncio .Future [int ] = asyncio .get_event_loop ().create_future ()
664+ match args .server :
665+ case "granian" :
666+ serve_task = asyncio .create_task (
667+ serve_granian (
668+ request , args .mode , certfile , keyfile , cafile , port_future
669+ )
670+ )
671+ case "gunicorn" :
672+ if args .mode == "async" :
673+ msg = "gunicorn does not support async mode"
674+ raise ValueError (msg )
675+ serve_task = asyncio .create_task (
676+ serve_gunicorn (request , certfile , keyfile , cafile , port_future )
677+ )
678+ case "hypercorn" :
679+ serve_task = asyncio .create_task (
680+ serve_hypercorn (
681+ request , args .mode , certfile , keyfile , cafile , port_future
682+ )
683+ )
684+ case "uvicorn" :
685+ if args .mode == "sync" :
686+ msg = "uvicorn does not support sync mode"
687+ raise ValueError (msg )
688+ serve_task = asyncio .create_task (
689+ serve_uvicorn (request , certfile , keyfile , cafile , port_future )
690+ )
691+ port = await port_future
692+ response = ServerCompatResponse ()
693+ response .host = "127.0.0.1"
694+ response .port = port
695+ if request .use_tls :
696+ response .pem_cert = request .server_creds .cert
697+ response_buf = response .SerializeToString ()
698+ size_buf = len (response_buf ).to_bytes (4 , byteorder = "big" )
699+ stdout .write (size_buf )
700+ stdout .write (response_buf )
701+ await stdout .drain ()
702+ asyncio .get_event_loop ().add_signal_handler (signal .SIGTERM , serve_task .cancel )
703+ await serve_task
501704
502705
503706if __name__ == "__main__" :
0 commit comments