1+ # SPDX-FileCopyrightText: 2024 LangChain, Inc.
2+ # SPDX-License-Identifier: MIT
3+ import time
4+ from functools import wraps
5+
6+ from fastmcp .server .dependencies import get_context
7+ from nc_py_api import NextcloudApp
8+ from fastmcp .server .middleware import Middleware , MiddlewareContext , CallNext
9+ from fastmcp .tools import Tool
10+ from mcp import types as mt
11+ from ex_app .lib .tools import get_tools
12+ import requests
13+
14+ def get_user (authorization_header : str , nc : NextcloudApp ) -> str :
15+ response = requests .get (
16+ f"{ nc .app_cfg .endpoint } /ocs/v2.php/cloud/user" ,
17+ headers = {
18+ "Accept" : "application/json" ,
19+ "Ocs-Apirequest" : "1" ,
20+ "Authorization" : authorization_header ,
21+ },
22+ )
23+ if response .status_code != 200 :
24+ raise Exception ("Failed to get user info" )
25+ return response .json ()["ocs" ]["data" ]["id" ]
26+
27+
28+ class UserAuthMiddleware (Middleware ):
29+ async def on_message (self , context : MiddlewareContext , call_next ):
30+ # Middleware stores user info in context state
31+ authorization_header = context .fastmcp_context .request_context .request .headers .get ("Authorization" )
32+ if authorization_header is None :
33+ raise Exception ("Authorization header is missing/invalid" )
34+ nc = NextcloudApp ()
35+ user = get_user (authorization_header , nc )
36+ nc .set_user (user )
37+ context .fastmcp_context .set_state ("nextcloud" , nc )
38+ return await call_next (context )
39+
40+
41+ LAST_MCP_TOOL_UPDATE = 0
42+
43+
44+ class ToolListMiddleware (Middleware ):
45+ def __init__ (self , mcp ):
46+ self .mcp = mcp
47+
48+ async def on_message (
49+ self ,
50+ context : MiddlewareContext [mt .ListToolsRequest ],
51+ call_next : CallNext [mt .ListToolsRequest , list [Tool ]],
52+ ) -> list [Tool ]:
53+ global LAST_MCP_TOOL_UPDATE
54+ if LAST_MCP_TOOL_UPDATE + 60 < time .time ():
55+ safe , dangerous = await get_tools (context .fastmcp_context .get_state ("nextcloud" ))
56+ tools = await self .mcp .get_tools ()
57+ if LAST_MCP_TOOL_UPDATE + 60 < time .time ():
58+ for tool in tools .keys ():
59+ self .mcp .remove_tool (tool )
60+ for tool in safe + dangerous :
61+ if not hasattr (tool , "func" ) or tool .func is None :
62+ continue
63+ self .mcp .tool ()(mcp_tool (tool .func ))
64+ LAST_MCP_TOOL_UPDATE = time .time ()
65+ return await call_next (context )
66+
67+ # Regenerates the tools with the correct nc object
68+ def mcp_tool (tool ):
69+ @wraps (tool )
70+ async def wrapper (* args , ** kwargs ):
71+ ctx = get_context ()
72+ nc = ctx .get_state ('nextcloud' )
73+ safe , dangerous = await get_tools (nc )
74+ tools = safe + dangerous
75+ for t in tools :
76+ if hasattr (t , "func" ) and t .func and t .name == tool .__name__ :
77+ return t .func (* args , ** kwargs )
78+ raise RuntimeError ("Tool not found" )
79+ return wrapper
0 commit comments