Skip to content

Commit baac41b

Browse files
authored
Merge pull request #2597 from opentensor/test/use_asynccontextmanager_for_fastapi_lifespan
test: use asynccontextmanager for FastAPI lifespan
2 parents aacf074 + 76142a5 commit baac41b

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

tests/unit_tests/test_axon.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
import asyncio
2+
import contextlib
13
import re
4+
import threading
25
import time
36
from dataclasses import dataclass
47
from typing import Any, Optional, Tuple
58
from unittest import IsolatedAsyncioTestCase
69
from unittest.mock import AsyncMock, MagicMock, patch
710

11+
import aiohttp
812
import fastapi
913
import netaddr
1014
import pydantic
1115
import pytest
16+
import uvicorn
1217
from fastapi.testclient import TestClient
1318
from starlette.requests import Request
1419

15-
from bittensor.core.axon import AxonMiddleware, Axon
20+
from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer
1621
from bittensor.core.errors import RunException
1722
from bittensor.core.settings import version_as_int
1823
from bittensor.core.stream import StreamingSynapse
@@ -765,3 +770,50 @@ async def forward_fn(synapse: streaming_synapse_cls):
765770
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
766771
},
767772
)
773+
774+
775+
@pytest.mark.asyncio
776+
async def test_threaded_fastapi():
777+
server_started = threading.Event()
778+
server_stopped = threading.Event()
779+
780+
@contextlib.asynccontextmanager
781+
async def lifespan(app):
782+
server_started.set()
783+
yield
784+
server_stopped.set()
785+
786+
app = fastapi.FastAPI(
787+
lifespan=lifespan,
788+
)
789+
app.get("/")(lambda: "Hello World")
790+
791+
server = FastAPIThreadedServer(
792+
uvicorn.Config(app, loop="none"),
793+
)
794+
server.start()
795+
796+
server_started.wait(3.0)
797+
798+
async def wait_for_server():
799+
while not (server.started or server_stopped.is_set()):
800+
await asyncio.sleep(1.0)
801+
802+
await asyncio.wait_for(wait_for_server(), 7.0)
803+
804+
assert server.is_running is True
805+
806+
async with aiohttp.ClientSession(
807+
base_url="http://127.0.0.1:8000",
808+
) as session:
809+
async with session.get("/") as response:
810+
assert await response.text() == '"Hello World"'
811+
812+
server.stop()
813+
814+
assert server.should_exit is True
815+
816+
server_stopped.wait()
817+
818+
with pytest.raises(aiohttp.ClientConnectorError):
819+
await session.get("/")

0 commit comments

Comments
 (0)