Skip to content

Commit d9ec8aa

Browse files
committed
Run conformance tests with granian, gunicorn, uvicorn too.
Signed-off-by: Anuraag Agrawal <[email protected]>
1 parent fe71e4f commit d9ec8aa

File tree

4 files changed

+439
-101
lines changed

4 files changed

+439
-101
lines changed

conformance/test/server.py

Lines changed: 275 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import argparse
22
import asyncio
3+
import os
4+
import re
35
import signal
6+
import socket
7+
import ssl
8+
import sys
49
import time
510
from collections.abc import AsyncIterator, Iterator
6-
from contextlib import ExitStack
7-
from ssl import VerifyMode
11+
from contextlib import ExitStack, closing
812
from tempfile import NamedTemporaryFile
913
from typing import TYPE_CHECKING, Literal, TypeVar
1014

@@ -36,7 +40,6 @@
3640
UnaryResponseDefinition,
3741
)
3842
from google.protobuf.any_pb2 import Any
39-
from hypercorn.asyncio import serve as hypercorn_serve
4043
from hypercorn.config import Config as HypercornConfig
4144
from hypercorn.logging import Logger
4245

@@ -402,76 +405,228 @@ async def info(self, message: str, *args: Any, **kwargs: Any) -> None:
402405
await super().info(message, *args, **kwargs)
403406

404407

405-
async def serve(
406-
request: ServerCompatRequest, mode: Literal["sync", "async"]
407-
) -> tuple[asyncio.Task, int]:
408-
read_max_bytes = request.message_receive_limit or None
409-
match mode:
410-
case "async":
411-
app = ConformanceServiceASGIApplication(
412-
TestService(), read_max_bytes=read_max_bytes
413-
)
414-
case "sync":
415-
app = ConformanceServiceWSGIApplication(
416-
TestServiceSync(), read_max_bytes=read_max_bytes
417-
)
418-
419-
conf = HypercornConfig()
420-
conf.bind = ["127.0.0.1:0"]
408+
read_max_bytes = os.getenv("READ_MAX_BYTES")
409+
if read_max_bytes is not None:
410+
read_max_bytes = int(read_max_bytes)
421411

422-
cleanup = ExitStack()
423-
if request.use_tls:
424-
cert_file = cleanup.enter_context(NamedTemporaryFile())
425-
key_file = cleanup.enter_context(NamedTemporaryFile())
426-
cert_file.write(request.server_creds.cert)
427-
cert_file.flush()
428-
key_file.write(request.server_creds.key)
429-
key_file.flush()
430-
conf.certfile = cert_file.name
431-
conf.keyfile = key_file.name
432-
if request.client_tls_cert:
433-
ca_cert_file = cleanup.enter_context(NamedTemporaryFile())
434-
ca_cert_file.write(request.client_tls_cert)
435-
ca_cert_file.flush()
436-
conf.ca_certs = ca_cert_file.name
437-
conf.verify_mode = VerifyMode.CERT_REQUIRED
438-
439-
conf._log = PortCapturingLogger(conf)
412+
asgi_app = ConformanceServiceASGIApplication(
413+
TestService(), read_max_bytes=read_max_bytes
414+
)
415+
wsgi_app = ConformanceServiceWSGIApplication(
416+
TestServiceSync(), read_max_bytes=read_max_bytes
417+
)
440418

441-
shutdown_event = asyncio.Event()
442419

443-
def _signal_handler(*_) -> None:
444-
cleanup.close()
445-
shutdown_event.set()
420+
def _server_env(request: ServerCompatRequest) -> dict[str, str]:
421+
pythonpath = os.pathsep.join(sys.path)
422+
env = {
423+
**os.environ,
424+
"PYTHONPATH": pythonpath,
425+
"PYTHONHOME": f"{sys.prefix}:{sys.exec_prefix}",
426+
}
427+
if request.message_receive_limit:
428+
env["READ_MAX_BYTES"] = str(request.message_receive_limit)
429+
return env
430+
431+
432+
_port_regex = re.compile(r".*://[^:]+:(\d+).*")
433+
434+
435+
async def serve_granian(
436+
request: ServerCompatRequest,
437+
mode: Literal["sync", "async"],
438+
certfile: str | None,
439+
keyfile: str | None,
440+
cafile: str | None,
441+
port_future: asyncio.Future[int],
442+
):
443+
# Granian seems to have a bug that it prints out 0 rather than the resolved port,
444+
# so we need to determine it ourselves. If we see race conditions because of it,
445+
# we can set max-servers=1 in the runner.
446+
port = _find_free_port()
447+
args = [f"--port={port}"]
448+
if certfile:
449+
args.append(f"--ssl-certificate={certfile}")
450+
if keyfile:
451+
args.append(f"--ssl-keyfile={keyfile}")
452+
if cafile:
453+
args.append(f"--ssl-ca={cafile}")
454+
args.append("--ssl-client-verify")
455+
456+
if mode == "sync":
457+
args.append("--interface=wsgi")
458+
args.append("server:wsgi_app")
459+
else:
460+
args.append("--interface=asgi")
461+
args.append("server:asgi_app")
462+
463+
proc = await asyncio.create_subprocess_exec(
464+
"granian",
465+
*args,
466+
stderr=asyncio.subprocess.STDOUT,
467+
stdout=asyncio.subprocess.PIPE,
468+
limit=1024,
469+
env=_server_env(request),
470+
)
471+
stdout = proc.stdout
472+
assert stdout is not None # noqa: S101
473+
try:
474+
for _ in range(100):
475+
line = await stdout.readline()
476+
if b"Listening at:" in line:
477+
break
478+
port_future.set_result(port)
479+
await proc.wait()
480+
except asyncio.CancelledError:
481+
proc.terminate()
482+
await proc.wait()
483+
484+
485+
async def serve_gunicorn(
486+
request: ServerCompatRequest,
487+
certfile: str | None,
488+
keyfile: str | None,
489+
cafile: str | None,
490+
port_future: asyncio.Future[int],
491+
):
492+
args = ["--bind=127.0.0.1:0", "--workers=8"]
493+
if certfile:
494+
args.append(f"--certfile={certfile}")
495+
if keyfile:
496+
args.append(f"--keyfile={keyfile}")
497+
if cafile:
498+
args.append(f"--ca-certs={cafile}")
499+
args.append(f"--cert-reqs={ssl.CERT_REQUIRED}")
500+
501+
args.append("server:wsgi_app")
502+
503+
proc = await asyncio.create_subprocess_exec(
504+
"gunicorn",
505+
*args,
506+
stderr=asyncio.subprocess.STDOUT,
507+
stdout=asyncio.subprocess.PIPE,
508+
limit=1024,
509+
env=_server_env(request),
510+
)
511+
stdout = proc.stdout
512+
assert stdout is not None # noqa: S101
513+
try:
514+
for _ in range(100):
515+
line = await stdout.readline()
516+
match = _port_regex.match(line.decode("utf-8"))
517+
if match:
518+
port_future.set_result(int(match.group(1)))
519+
break
520+
await proc.wait()
521+
except asyncio.CancelledError:
522+
proc.terminate()
523+
await proc.wait()
524+
525+
526+
async def serve_hypercorn(
527+
request: ServerCompatRequest,
528+
mode: Literal["sync", "async"],
529+
certfile: str | None,
530+
keyfile: str | None,
531+
cafile: str | None,
532+
port_future: asyncio.Future[int],
533+
):
534+
args = ["--bind=localhost:0"]
535+
if certfile:
536+
args.append(f"--certfile={certfile}")
537+
if keyfile:
538+
args.append(f"--keyfile={keyfile}")
539+
if cafile:
540+
args.append(f"--ca-certs={cafile}")
541+
args.append("--verify-mode=CERT_REQUIRED")
542+
543+
if mode == "sync":
544+
args.append("server:wsgi_app")
545+
else:
546+
args.append("server:asgi_app")
547+
548+
proc = await asyncio.create_subprocess_exec(
549+
"hypercorn",
550+
*args,
551+
stderr=asyncio.subprocess.STDOUT,
552+
stdout=asyncio.subprocess.PIPE,
553+
limit=1024,
554+
env=_server_env(request),
555+
)
556+
stdout = proc.stdout
557+
assert stdout is not None # noqa: S101
558+
try:
559+
for _ in range(100):
560+
line = await stdout.readline()
561+
match = _port_regex.match(line.decode("utf-8"))
562+
if match:
563+
port_future.set_result(int(match.group(1)))
564+
break
565+
await proc.wait()
566+
except asyncio.CancelledError:
567+
proc.terminate()
568+
await proc.wait()
569+
570+
571+
async def serve_uvicorn(
572+
request: ServerCompatRequest,
573+
certfile: str | None,
574+
keyfile: str | None,
575+
cafile: str | None,
576+
port_future: asyncio.Future[int],
577+
):
578+
args = ["--port=0", "--no-access-log"]
579+
if certfile:
580+
args.append(f"--ssl-certfile={certfile}")
581+
if keyfile:
582+
args.append(f"--ssl-keyfile={keyfile}")
583+
if cafile:
584+
args.append(f"--ssl-ca-certs={cafile}")
585+
args.append(f"--ssl-cert-reqs={ssl.CERT_REQUIRED}")
586+
587+
args.append("server:asgi_app")
588+
589+
proc = await asyncio.create_subprocess_exec(
590+
"uvicorn",
591+
*args,
592+
stderr=asyncio.subprocess.STDOUT,
593+
stdout=asyncio.subprocess.PIPE,
594+
limit=1024,
595+
env=_server_env(request),
596+
)
597+
stdout = proc.stdout
598+
assert stdout is not None # noqa: S101
599+
try:
600+
for _ in range(100):
601+
line = await stdout.readline()
602+
match = _port_regex.match(line.decode("utf-8"))
603+
if match:
604+
port_future.set_result(int(match.group(1)))
605+
break
606+
await proc.wait()
607+
except asyncio.CancelledError:
608+
proc.terminate()
609+
await proc.wait()
446610

447-
loop = asyncio.get_event_loop()
448-
loop.add_signal_handler(signal.SIGTERM, _signal_handler)
449-
loop.add_signal_handler(signal.SIGINT, _signal_handler)
450611

451-
serve_task = loop.create_task(
452-
hypercorn_serve(
453-
app, # pyright:ignore[reportArgumentType] - some incompatibility in type
454-
conf,
455-
shutdown_trigger=shutdown_event.wait,
456-
mode="asgi" if mode == "async" else "wsgi",
457-
)
458-
)
459-
port = -1
460-
for _ in range(100):
461-
port = conf._log.port
462-
if port != -1:
463-
break
464-
await asyncio.sleep(0.01)
465-
return serve_task, port
612+
def _find_free_port():
613+
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
614+
s.bind(("", 0))
615+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
616+
return s.getsockname()[1]
466617

467618

468619
class Args(argparse.Namespace):
469620
mode: Literal["sync", "async"]
621+
server: Literal["granian", "hypercorn", "uvicorn"]
470622

471623

472624
async def main() -> None:
473-
parser = argparse.ArgumentParser(description="Conformance client")
625+
parser = argparse.ArgumentParser(description="Conformance server")
474626
parser.add_argument("--mode", choices=["sync", "async"])
627+
parser.add_argument(
628+
"--server", choices=["granian", "gunicorn", "hypercorn", "uvicorn"]
629+
)
475630
args = parser.parse_args(namespace=Args())
476631

477632
stdin, stdout = await create_standard_streams()
@@ -485,19 +640,67 @@ async def main() -> None:
485640
request = ServerCompatRequest()
486641
request.ParseFromString(request_buf)
487642

488-
serve_task, port = await serve(request, args.mode)
489-
response = ServerCompatResponse()
490-
response.host = "127.0.0.1"
491-
response.port = port
643+
cleanup = ExitStack()
644+
certfile = None
645+
keyfile = None
646+
cafile = None
492647
if request.use_tls:
493-
response.pem_cert = request.server_creds.cert
494-
response_buf = response.SerializeToString()
495-
size_buf = len(response_buf).to_bytes(4, byteorder="big")
496-
stdout.write(size_buf)
497-
stdout.write(response_buf)
498-
await stdout.drain()
499-
# Runner will send sigterm which is handled by serve
500-
await serve_task
648+
cert_file = cleanup.enter_context(NamedTemporaryFile())
649+
key_file = cleanup.enter_context(NamedTemporaryFile())
650+
cert_file.write(request.server_creds.cert)
651+
cert_file.flush()
652+
key_file.write(request.server_creds.key)
653+
key_file.flush()
654+
certfile = cert_file.name
655+
keyfile = key_file.name
656+
if request.client_tls_cert:
657+
ca_cert_file = cleanup.enter_context(NamedTemporaryFile())
658+
ca_cert_file.write(request.client_tls_cert)
659+
ca_cert_file.flush()
660+
cafile = ca_cert_file.name
661+
662+
with cleanup:
663+
port_future: asyncio.Future[int] = asyncio.get_event_loop().create_future()
664+
match args.server:
665+
case "granian":
666+
serve_task = asyncio.create_task(
667+
serve_granian(
668+
request, args.mode, certfile, keyfile, cafile, port_future
669+
)
670+
)
671+
case "gunicorn":
672+
if args.mode == "async":
673+
msg = "gunicorn does not support async mode"
674+
raise ValueError(msg)
675+
serve_task = asyncio.create_task(
676+
serve_gunicorn(request, certfile, keyfile, cafile, port_future)
677+
)
678+
case "hypercorn":
679+
serve_task = asyncio.create_task(
680+
serve_hypercorn(
681+
request, args.mode, certfile, keyfile, cafile, port_future
682+
)
683+
)
684+
case "uvicorn":
685+
if args.mode == "sync":
686+
msg = "uvicorn does not support sync mode"
687+
raise ValueError(msg)
688+
serve_task = asyncio.create_task(
689+
serve_uvicorn(request, certfile, keyfile, cafile, port_future)
690+
)
691+
port = await port_future
692+
response = ServerCompatResponse()
693+
response.host = "127.0.0.1"
694+
response.port = port
695+
if request.use_tls:
696+
response.pem_cert = request.server_creds.cert
697+
response_buf = response.SerializeToString()
698+
size_buf = len(response_buf).to_bytes(4, byteorder="big")
699+
stdout.write(size_buf)
700+
stdout.write(response_buf)
701+
await stdout.drain()
702+
asyncio.get_event_loop().add_signal_handler(signal.SIGTERM, serve_task.cancel)
703+
await serve_task
501704

502705

503706
if __name__ == "__main__":

0 commit comments

Comments
 (0)