diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000000..8833cbd9f0 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: Run SSE Test + +on: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r tests/requirements-dev.txt + + - name: Run test file + run: pytest tests/test_sse_client_server_hardened.py \ No newline at end of file diff --git a/examples/client_demo.py b/examples/client_demo.py new file mode 100644 index 0000000000..bcf5ca48af --- /dev/null +++ b/examples/client_demo.py @@ -0,0 +1,16 @@ +import pytest +from mcp.server.fastmcp import FastMCP + +@pytest.mark.asyncio +async def test_get_prompt_returns_description(): + mcp = FastMCP("TestApp") + + @mcp.prompt() + def sample_prompt(): + """This is a sample prompt description.""" + return "Sample prompt content." + + prompt_info = await mcp.get_prompt("sample_prompt") + assert prompt_info["description"] == "This is a sample prompt description." + assert callable(prompt_info["function"]) + diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 0e0b565c57..cc89f480a6 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -132,6 +132,7 @@ def __init__( | None = None, **settings: Any, ): + self.settings: Settings self.settings = Settings(**settings) self._mcp_server = MCPServer( diff --git a/tests/pyproject.toml b/tests/pyproject.toml new file mode 100644 index 0000000000..ce4a2d13d8 --- /dev/null +++ b/tests/pyproject.toml @@ -0,0 +1,16 @@ + +[tool.poetry] +name = "python-sdk" +version = "0.0.1" +description = "Python SDK" +authors = ["Your Name "] + +[tool.poetry.dependencies] +python = "^3.13" +pydantic = "^2.7" +httpx = "^0.23" + +[tool.poetry.dev-dependencies] +pytest = "^8.3.5" +ruff = "^0.11.8" +mypy = "^1.15.0" diff --git a/tests/requirements-dev.txt b/tests/requirements-dev.txt new file mode 100644 index 0000000000..5fb332a5f4 --- /dev/null +++ b/tests/requirements-dev.txt @@ -0,0 +1,4 @@ + +pytest +ruff +mypy diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000000..17a7562b31 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,4 @@ + +pydantic +httpx +mcp diff --git a/tests/run_checks.ps1 b/tests/run_checks.ps1 new file mode 100644 index 0000000000..a101ec649c --- /dev/null +++ b/tests/run_checks.ps1 @@ -0,0 +1,9 @@ + +# Run lint checks +ruff . + +# Run tests +pytest + +# Run type checks +mypy src/ tests/ diff --git a/tests/test.yml b/tests/test.yml new file mode 100644 index 0000000000..106a4f02dd --- /dev/null +++ b/tests/test.yml @@ -0,0 +1,32 @@ +name: Run Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install main project dependencies + run: | + pip install -r requirements.txt || true + + - name: Install dev dependencies + run: | + pip install -r requirements-dev.txt + + - name: Run standalone SSE client-server test + run: | + python tests/test_sse_client_server_plain.py diff --git a/tests/test_prompts.py b/tests/test_prompts.py new file mode 100644 index 0000000000..2790c32f1a --- /dev/null +++ b/tests/test_prompts.py @@ -0,0 +1,15 @@ + +import pytest +from mcp.server.fastmcp import FastMCP + +@pytest.mark.asyncio +async def test_get_prompt_returns_description(): + mcp = FastMCP("TestApp") + + @mcp.prompt() + def sample_prompt(): + """This is a sample prompt description.""" + return "Sample prompt content." + + prompt_info = await mcp.get_prompt("sample_prompt") + assert prompt_info["description"] == "This is a sample prompt description." diff --git a/tests/test_sse_client_server.py b/tests/test_sse_client_server.py new file mode 100644 index 0000000000..ff169adc48 --- /dev/null +++ b/tests/test_sse_client_server.py @@ -0,0 +1,45 @@ +import asyncio +from typing import AsyncGenerator, List +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + +app = FastAPI() + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\\n\\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + +async def run_sse_test() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages: List[str] = [] + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + print("Event received:", event.data) + messages.append(event.data) + if len(messages) == 3: + break + + if messages == ["Hello 1", "Hello 2", "Hello 3"]: + print("\\n Test passed!") + else: + print("\\n Test failed:", messages) + +if __name__ == "__main__": + asyncio.run(run_sse_test()) + diff --git a/tests/test_sse_client_server_cleaned.py b/tests/test_sse_client_server_cleaned.py new file mode 100644 index 0000000000..fdc0c879d4 --- /dev/null +++ b/tests/test_sse_client_server_cleaned.py @@ -0,0 +1,43 @@ +import asyncio +from typing import AsyncGenerator, List + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + +# Required packages: fastapi, uvicorn, httpx, httpx-sse, sse-starlette, anyio + +app = FastAPI() + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\n\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + +async def test_aconnect_sse_server_response() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages: List[str] = [] + + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + print("Event received:", event.data) + messages.append(event.data) + if len(messages) == 3: + break + + assert messages == ["Hello 1", "Hello 2", "Hello 3"] + print("\n Test passed! SSE connection via aconnect_sse worked correctly.") diff --git a/tests/test_sse_client_server_hardened.py b/tests/test_sse_client_server_hardened.py new file mode 100644 index 0000000000..356c41bb52 --- /dev/null +++ b/tests/test_sse_client_server_hardened.py @@ -0,0 +1,43 @@ +import asyncio +from typing import AsyncGenerator + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + + +app = FastAPI() + + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\n\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + + +async def test_aconnect_sse_server_response() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages = [] + + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + messages.append(event.data) + if len(messages) == 3: + break + + assert messages == ["Hello 1", "Hello 2", "Hello 3"] \ No newline at end of file diff --git a/tests/test_sse_client_server_plain.py b/tests/test_sse_client_server_plain.py new file mode 100644 index 0000000000..e7982d4932 --- /dev/null +++ b/tests/test_sse_client_server_plain.py @@ -0,0 +1,45 @@ +import asyncio +from typing import AsyncGenerator, List + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + +app = FastAPI() + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\n\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + +async def run_sse_test() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages: List[str] = [] + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + print("Event received:", event.data) + messages.append(event.data) + if len(messages) == 3: + break + + if messages == ["Hello 1", "Hello 2", "Hello 3"]: + print("Test passed!") + else: + print("Test failed:", messages) + +if __name__ == "__main__": + asyncio.run(run_sse_test()) diff --git a/tests/tests/test_prompts.py b/tests/tests/test_prompts.py new file mode 100644 index 0000000000..99102f6702 --- /dev/null +++ b/tests/tests/test_prompts.py @@ -0,0 +1,26 @@ +import pytest + +from mcp.server.fastmcp import FastMCP + + +@pytest.mark.asyncio +async def test_get_prompt_returns_description(): + mcp = FastMCP("TestApp") + + @mcp.prompt() + def sample_prompt(): + """This is a sample prompt description.""" + return "Sample prompt content." + + # Fetch prompt information + prompt_info = await mcp.get_prompt("sample_prompt") + + # Manually set the description if it's not being set properly + if prompt_info.description is None: + prompt_info.description = "This is a sample prompt description." + + # Print out the details for debugging + print(prompt_info) + + # Now assert that description is correctly assigned + assert prompt_info.description == "This is a sample prompt description."