diff --git a/pyproject.toml b/pyproject.toml index 65a237b2..c936b042 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,11 +44,13 @@ dependencies = [ "requests", "resolvelib", "requests-mock", + "starlette", "stevedore", "tomlkit", "tqdm", "wheel", "uv>=0.8.19", + "uvicorn", ] [project.optional-dependencies] diff --git a/src/fromager/commands/server.py b/src/fromager/commands/server.py index 37d17543..d55cdf40 100644 --- a/src/fromager/commands/server.py +++ b/src/fromager/commands/server.py @@ -27,10 +27,10 @@ def wheel_server( ) -> None: "Start a web server to serve the local wheels-repo" server.update_wheel_mirror(wkctx) - t = server.run_wheel_server( + _, _, thread = server.run_wheel_server( wkctx, address=address, port=port, ) print(f"Listening on {wkctx.wheel_server_url}") - t.join() + thread.join() diff --git a/src/fromager/server.py b/src/fromager/server.py index d0e86c47..a529e601 100644 --- a/src/fromager/server.py +++ b/src/fromager/server.py @@ -1,15 +1,24 @@ from __future__ import annotations -import functools -import http.server +import asyncio import logging import os import pathlib import shutil +import socket +import stat +import textwrap import threading import typing +from urllib.parse import quote +import uvicorn from packaging.utils import parse_wheel_filename +from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import FileResponse, HTMLResponse, RedirectResponse, Response +from starlette.routing import Route from .threading_utils import with_thread_lock @@ -19,11 +28,6 @@ logger = logging.getLogger(__name__) -class LoggingHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): - def log_message(self, format: str, *args: typing.Any) -> None: - logger.debug(format, *args) - - def start_wheel_server(ctx: context.WorkContext) -> None: update_wheel_mirror(ctx) if ctx.wheel_server_url: @@ -34,33 +38,24 @@ def start_wheel_server(ctx: context.WorkContext) -> None: def run_wheel_server( ctx: context.WorkContext, - address: str = "localhost", + address: str = "127.0.0.1", port: int = 0, -) -> threading.Thread: - server = http.server.ThreadingHTTPServer( - (address, port), - functools.partial(LoggingHTTPRequestHandler, directory=str(ctx.wheels_repo)), - bind_and_activate=False, +) -> tuple[uvicorn.Server, socket.socket, threading.Thread]: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + + app = make_app(ctx.wheel_server_dir) + server, sock, thread = _run_background_thread( + loop=loop, app=app, host=address, port=port ) - server.timeout = 0.5 - server.allow_reuse_address = True - - logger.debug(f"address {server.server_address}") - server.server_bind() - ctx.wheel_server_url = f"http://{address}:{server.server_port}/simple/" - logger.debug("starting wheel server at %s", ctx.wheel_server_url) - server.server_activate() + realport = sock.getsockname()[1] + ctx.wheel_server_url = f"http://{address}:{realport}/simple/" - def serve_forever(server: http.server.ThreadingHTTPServer) -> None: - # ensure server.server_close() is called - with server: - server.serve_forever() - - t = threading.Thread(target=serve_forever, args=(server,)) - t.setDaemon(True) - t.start() - return t + logger.info("started wheel server at %s", ctx.wheel_server_url) + return server, sock, thread @with_thread_lock() @@ -92,3 +87,149 @@ def update_wheel_mirror(ctx: context.WorkContext) -> None: logger.debug("linking %s -> %s into local index", wheel.name, relpath) simple_dest_filename.parent.mkdir(parents=True, exist_ok=True) simple_dest_filename.symlink_to(relpath) + + +class SimpleHTMLIndex: + """Simple HTML Repository API (1.0) + + https://packaging.python.org/en/latest/specifications/simple-repository-api/ + """ + + html_index = textwrap.dedent( + """\ + + + + + Simple index + + + {entries} + + + """ + ) + + html_project = textwrap.dedent( + """\ + + + + + Links for {project} + + +

Links for {project}

+ {entries} + + + """ + ) + + def __init__(self, basedir: pathlib.Path) -> None: + self.basedir = basedir.resolve() + + def _as_anchor(self, prefix: str, direntry: os.DirEntry) -> str: + quoted = quote(direntry.name) + return f'{quoted}
' + + async def root(self, request: Request) -> Response: + return RedirectResponse(url="/simple") + + async def index_page(self, request: Request) -> Response: + prefix = "/simple" + try: + dirs = [ + self._as_anchor(prefix, direntry) + for direntry in os.scandir(self.basedir) + if direntry.is_dir(follow_symlinks=False) + ] + except FileNotFoundError: + raise HTTPException( + status_code=404, detail=f"'{self.basedir}' missing" + ) from None + + content = self.html_index.format(entries="\n".join(dirs)) + return HTMLResponse(content=content) + + async def project_page(self, request: Request) -> Response: + project = request.path_params["project"] + project_dir = self.basedir / project + prefix = f"/simple/{project}" + try: + dirs = [ + self._as_anchor(prefix, direntry) + for direntry in os.scandir(project_dir) + if direntry.name.endswith((".whl", ".whl.metadata", ".tar.gz")) + and direntry.is_file(follow_symlinks=True) + ] + except FileNotFoundError: + raise HTTPException( + status_code=404, detail=f"'{project_dir}' missing" + ) from None + content = self.html_project.format( + project=quote(project), entries="\n".join(dirs) + ) + return HTMLResponse(content=content) + + async def server_file(self, request: Request) -> Response: + project = request.path_params["project"] + filename = request.path_params["filename"] + + path: pathlib.Path = self.basedir / project / filename + try: + stat_result = path.stat(follow_symlinks=True) + except FileNotFoundError: + raise HTTPException(status_code=404, detail="File not found") from None + if not stat.S_ISREG(stat_result.st_mode): + raise HTTPException(status_code=400, detail="Not a regular file") + + if filename.endswith(".tar.gz"): + media_type = "application/x-tar" + elif filename.endswith(".whl"): + media_type = "application/zip" + elif filename.endswith(".whl.metadata"): + media_type = "binary/octet-stream" + else: + raise HTTPException(status_code=400, detail="Bad request") + + return FileResponse(path, media_type=media_type, stat_result=stat_result) + + +def make_app(basedir: pathlib.Path) -> Starlette: + """Create a Starlette app with routing""" + si = SimpleHTMLIndex(basedir) + routes: list[Route] = [ + Route("/", endpoint=si.root), + Route("/simple", endpoint=si.index_page), + Route("/simple/{project:str}", endpoint=si.project_page), + Route("/simple/{project:str}/{filename:str}", endpoint=si.server_file), + ] + return Starlette(routes=routes) + + +def _run_background_thread( + loop: asyncio.AbstractEventLoop, + app: Starlette, + host="127.0.0.1", + port=0, + **kwargs, +) -> tuple[uvicorn.Server, socket.socket, threading.Thread]: + """Run uvicorn server in a daemon thread""" + config = uvicorn.Config(app=app, host=host, port=port, **kwargs) + server = uvicorn.Server(config=config) + sock = server.config.bind_socket() + + def _run_background() -> None: + asyncio.set_event_loop(loop) + loop.run_until_complete(server.serve(sockets=[sock])) + + thread = threading.Thread(target=_run_background, args=(), daemon=True) + thread.start() + return server, sock, thread + + +def stop_server(server: uvicorn.Server, loop: asyncio.AbstractEventLoop) -> None: + """Stop server, blocks until server is shut down""" + fut = asyncio.run_coroutine_threadsafe(server.shutdown(), loop=loop) + fut.result()