|
17 | 17 |
|
18 | 18 |
|
19 | 19 | import re
|
| 20 | +import threading |
20 | 21 | import time
|
21 | 22 | from dataclasses import dataclass
|
22 | 23 | from typing import Any, Optional, Tuple
|
23 | 24 | from unittest import IsolatedAsyncioTestCase
|
24 | 25 | from unittest.mock import AsyncMock, MagicMock, patch
|
25 | 26 |
|
| 27 | +import aiohttp |
26 | 28 | import fastapi
|
27 | 29 | import netaddr
|
28 | 30 | import pydantic
|
29 | 31 | import pytest
|
| 32 | +import uvicorn |
30 | 33 | from fastapi.testclient import TestClient
|
31 | 34 | from starlette.requests import Request
|
32 | 35 |
|
33 |
| -from bittensor.core.axon import AxonMiddleware, Axon |
| 36 | +from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer |
34 | 37 | from bittensor.core.errors import RunException
|
35 | 38 | from bittensor.core.settings import version_as_int
|
36 | 39 | from bittensor.core.stream import StreamingSynapse
|
@@ -785,3 +788,45 @@ async def forward_fn(synapse: streaming_synapse_cls):
|
785 | 788 | "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
|
786 | 789 | },
|
787 | 790 | )
|
| 791 | + |
| 792 | + |
| 793 | +@pytest.mark.asyncio |
| 794 | +async def test_threaded_fastapi(): |
| 795 | + server_started = threading.Event() |
| 796 | + server_stopped = threading.Event() |
| 797 | + |
| 798 | + async def lifespan(app): |
| 799 | + server_started.set() |
| 800 | + yield |
| 801 | + server_stopped.set() |
| 802 | + |
| 803 | + app = fastapi.FastAPI( |
| 804 | + lifespan=lifespan, |
| 805 | + ) |
| 806 | + app.get("/")(lambda: "Hello World") |
| 807 | + |
| 808 | + server = FastAPIThreadedServer( |
| 809 | + uvicorn.Config( |
| 810 | + app, |
| 811 | + ), |
| 812 | + ) |
| 813 | + server.start() |
| 814 | + |
| 815 | + server_started.wait() |
| 816 | + |
| 817 | + assert server.is_running is True |
| 818 | + |
| 819 | + async with aiohttp.ClientSession( |
| 820 | + base_url="http://127.0.0.1:8000", |
| 821 | + ) as session: |
| 822 | + async with session.get("/") as response: |
| 823 | + assert await response.text() == '"Hello World"' |
| 824 | + |
| 825 | + server.stop() |
| 826 | + |
| 827 | + assert server.should_exit is True |
| 828 | + |
| 829 | + server_stopped.wait() |
| 830 | + |
| 831 | + with pytest.raises(aiohttp.ClientConnectorError): |
| 832 | + await session.get("/") |
0 commit comments