Skip to content

Commit 495601d

Browse files
authored
Merge branch 'staging' into feat/thewhaleking/add-neuron-certificates-query-map
2 parents 7c144a4 + baac41b commit 495601d

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

tests/unit_tests/test_axon.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
import asyncio
2+
import contextlib
13
import re
4+
import threading
25
import time
36
from dataclasses import dataclass
47
from typing import Any, Optional, Tuple
58
from unittest import IsolatedAsyncioTestCase
69
from unittest.mock import AsyncMock, MagicMock, patch
710

11+
import aiohttp
812
import fastapi
913
import netaddr
1014
import pydantic
1115
import pytest
16+
import uvicorn
1217
from fastapi.testclient import TestClient
1318
from starlette.requests import Request
1419

15-
from bittensor.core.axon import AxonMiddleware, Axon
20+
from bittensor.core.axon import Axon, AxonMiddleware, FastAPIThreadedServer
1621
from bittensor.core.errors import RunException
1722
from bittensor.core.settings import version_as_int
1823
from bittensor.core.stream import StreamingSynapse
@@ -765,3 +770,50 @@ async def forward_fn(synapse: streaming_synapse_cls):
765770
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
766771
},
767772
)
773+
774+
775+
@pytest.mark.asyncio
776+
async def test_threaded_fastapi():
777+
server_started = threading.Event()
778+
server_stopped = threading.Event()
779+
780+
@contextlib.asynccontextmanager
781+
async def lifespan(app):
782+
server_started.set()
783+
yield
784+
server_stopped.set()
785+
786+
app = fastapi.FastAPI(
787+
lifespan=lifespan,
788+
)
789+
app.get("/")(lambda: "Hello World")
790+
791+
server = FastAPIThreadedServer(
792+
uvicorn.Config(app, loop="none"),
793+
)
794+
server.start()
795+
796+
server_started.wait(3.0)
797+
798+
async def wait_for_server():
799+
while not (server.started or server_stopped.is_set()):
800+
await asyncio.sleep(1.0)
801+
802+
await asyncio.wait_for(wait_for_server(), 7.0)
803+
804+
assert server.is_running is True
805+
806+
async with aiohttp.ClientSession(
807+
base_url="http://127.0.0.1:8000",
808+
) as session:
809+
async with session.get("/") as response:
810+
assert await response.text() == '"Hello World"'
811+
812+
server.stop()
813+
814+
assert server.should_exit is True
815+
816+
server_stopped.wait()
817+
818+
with pytest.raises(aiohttp.ClientConnectorError):
819+
await session.get("/")

0 commit comments

Comments
 (0)