|
| 1 | +import asyncio |
| 2 | +import contextlib |
1 | 3 | import re
|
| 4 | +import threading |
2 | 5 | import time
|
3 | 6 | from dataclasses import dataclass
|
4 | 7 | from typing import Any, Optional, Tuple
|
5 | 8 | from unittest import IsolatedAsyncioTestCase
|
6 | 9 | from unittest.mock import AsyncMock, MagicMock, patch
|
7 | 10 |
|
| 11 | +import aiohttp |
8 | 12 | import fastapi
|
9 | 13 | import netaddr
|
10 | 14 | import pydantic
|
11 | 15 | import pytest
|
| 16 | +import uvicorn |
12 | 17 | from fastapi.testclient import TestClient
|
13 | 18 | from starlette.requests import Request
|
14 | 19 |
|
15 |
| -from bittensor.core.axon import AxonMiddleware, Axon |
| 20 | +from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer |
16 | 21 | from bittensor.core.errors import RunException
|
17 | 22 | from bittensor.core.settings import version_as_int
|
18 | 23 | from bittensor.core.stream import StreamingSynapse
|
@@ -774,3 +779,50 @@ async def forward_fn(synapse: streaming_synapse_cls):
|
774 | 779 | "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
|
775 | 780 | },
|
776 | 781 | )
|
| 782 | + |
| 783 | + |
| 784 | +@pytest.mark.asyncio |
| 785 | +async def test_threaded_fastapi(): |
| 786 | + server_started = threading.Event() |
| 787 | + server_stopped = threading.Event() |
| 788 | + |
| 789 | + @contextlib.asynccontextmanager |
| 790 | + async def lifespan(app): |
| 791 | + server_started.set() |
| 792 | + yield |
| 793 | + server_stopped.set() |
| 794 | + |
| 795 | + app = fastapi.FastAPI( |
| 796 | + lifespan=lifespan, |
| 797 | + ) |
| 798 | + app.get("/")(lambda: "Hello World") |
| 799 | + |
| 800 | + server = FastAPIThreadedServer( |
| 801 | + uvicorn.Config(app, loop="none"), |
| 802 | + ) |
| 803 | + server.start() |
| 804 | + |
| 805 | + server_started.wait(3.0) |
| 806 | + |
| 807 | + async def wait_for_server(): |
| 808 | + while not (server.started or server_stopped.is_set()): |
| 809 | + await asyncio.sleep(1.0) |
| 810 | + |
| 811 | + await asyncio.wait_for(wait_for_server(), 7.0) |
| 812 | + |
| 813 | + assert server.is_running is True |
| 814 | + |
| 815 | + async with aiohttp.ClientSession( |
| 816 | + base_url="http://127.0.0.1:8000", |
| 817 | + ) as session: |
| 818 | + async with session.get("/") as response: |
| 819 | + assert await response.text() == '"Hello World"' |
| 820 | + |
| 821 | + server.stop() |
| 822 | + |
| 823 | + assert server.should_exit is True |
| 824 | + |
| 825 | + server_stopped.wait() |
| 826 | + |
| 827 | + with pytest.raises(aiohttp.ClientConnectorError): |
| 828 | + await session.get("/") |
0 commit comments