Skip to content

Commit fdc2edb

Browse files
shams858Shamsul Arefincrivetimihai
authored
feat: Add tool_pre_invoke and tool_post_invoke plugin hooks (#686)
* feat: add tool hooks to plugin system Signed-off-by: Shamsul Arefin <[email protected]> * feat: connect plugins to tool invoke, pii_filter improvements Signed-off-by: Shamsul Arefin <[email protected]> * Add PluginViolationError: If plugin blocks tool invocation. to the docstring's Raises section. Signed-off-by: Shamsul Arefin <[email protected]> * Apply linter fixes * Rebase and lint Signed-off-by: Mihai Criveti <[email protected]> * Update logging and docs Signed-off-by: Mihai Criveti <[email protected]> --------- Signed-off-by: Shamsul Arefin <[email protected]> Signed-off-by: Mihai Criveti <[email protected]> Co-authored-by: Shamsul Arefin <[email protected]> Co-authored-by: Mihai Criveti <[email protected]>
1 parent 19d3945 commit fdc2edb

File tree

14 files changed

+1033
-16
lines changed

14 files changed

+1033
-16
lines changed

docs/docs/using/plugins/index.md

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ are defined as follows:
9494
| **description** | The description of the plugin configuration. | A plugin for replacing bad words. |
9595
| **version** | The version of the plugin configuration. | 0.1 |
9696
| **author** | The team that wrote the plugin. | MCP Context Forge |
97-
| **hooks** | A list of hooks for which the plugin will be executed. **Note**: currently supports two hooks: "prompt_pre_fetch", "prompt_post_fetch" | ["prompt_pre_fetch", "prompt_post_fetch"] |
97+
| **hooks** | A list of hooks for which the plugin will be executed. Supported hooks: "prompt_pre_fetch", "prompt_post_fetch", "tool_pre_invoke", "tool_post_invoke" | ["prompt_pre_fetch", "prompt_post_fetch", "tool_pre_invoke", "tool_post_invoke"] |
9898
| **tags** | Descriptive keywords that make the configuration searchable. | ["security", "filter"] |
9999
| **mode** | Mode of operation of the plugin. - enforce (stops during a violation), permissive (audits a violation but doesn't stop), disabled (disabled) | permissive |
100100
| **priority** | The priority in which the plugin will run - 0 is higher priority | 100 |
@@ -163,10 +163,25 @@ Currently implemented hooks:
163163
|------|-------------|-----------|
164164
| `prompt_pre_fetch` | Before prompt retrieval | Validate/modify prompt arguments |
165165
| `prompt_post_fetch` | After prompt rendering | Filter/transform rendered prompts |
166+
| `tool_pre_invoke` | Before tool invocation | Validate/modify tool arguments, block dangerous operations |
167+
| `tool_post_invoke` | After tool execution | Filter/transform tool results, audit tool usage |
168+
169+
### Tool Hooks Details
170+
171+
The tool hooks enable plugins to intercept and modify tool invocations:
172+
173+
- **`tool_pre_invoke`**: Receives the tool name and arguments before execution. Can modify arguments or block the invocation entirely.
174+
- **`tool_post_invoke`**: Receives the tool result after execution. Can modify the result or block it from being returned.
175+
176+
Example use cases:
177+
- PII detection and masking in tool inputs/outputs
178+
- Rate limiting specific tools
179+
- Audit logging of tool usage
180+
- Input validation and sanitization
181+
- Output filtering and transformation
166182

167183
Planned hooks (not yet implemented):
168184

169-
- `tool_pre_invoke` / `tool_post_invoke` - Tool execution guardrails
170185
- `resource_pre_fetch` / `resource_post_fetch` - Resource content filtering
171186
- `server_pre_register` / `server_post_register` - Server validation
172187
- `auth_pre_check` / `auth_post_check` - Custom authentication
@@ -179,12 +194,16 @@ Planned hooks (not yet implemented):
179194
```python
180195
from mcpgateway.plugins.framework.base import Plugin
181196
from mcpgateway.plugins.framework.models import PluginConfig
182-
from mcpgateway.plugins.framework.types import (
197+
from mcpgateway.plugins.framework.plugin_types import (
183198
PluginContext,
184199
PromptPrehookPayload,
185200
PromptPrehookResult,
186201
PromptPosthookPayload,
187-
PromptPosthookResult
202+
PromptPosthookResult,
203+
ToolPreInvokePayload,
204+
ToolPreInvokeResult,
205+
ToolPostInvokePayload,
206+
ToolPostInvokeResult
188207
)
189208
190209
class MyPlugin(Plugin):
@@ -251,6 +270,62 @@ class MyPlugin(Plugin):
251270
modified_payload=payload
252271
)
253272
273+
async def tool_pre_invoke(
274+
self,
275+
payload: ToolPreInvokePayload,
276+
context: PluginContext
277+
) -> ToolPreInvokeResult:
278+
"""Process tool before invocation."""
279+
280+
# Access tool name and arguments
281+
tool_name = payload.name
282+
args = payload.args
283+
284+
# Example: Block dangerous operations
285+
if tool_name == "file_delete" and "system" in str(args):
286+
return ToolPreInvokeResult(
287+
continue_processing=False,
288+
violation=PluginViolation(
289+
plugin_name=self.name,
290+
description="Dangerous operation blocked",
291+
violation_code="DANGEROUS_OP",
292+
details={"tool": tool_name}
293+
)
294+
)
295+
296+
# Example: Modify arguments
297+
if "sanitize_me" in args:
298+
args["sanitize_me"] = self.sanitize_input(args["sanitize_me"])
299+
return ToolPreInvokeResult(
300+
modified_payload=ToolPreInvokePayload(tool_name, args)
301+
)
302+
303+
return ToolPreInvokeResult()
304+
305+
async def tool_post_invoke(
306+
self,
307+
payload: ToolPostInvokePayload,
308+
context: PluginContext
309+
) -> ToolPostInvokeResult:
310+
"""Process tool after invocation."""
311+
312+
# Access tool result
313+
tool_name = payload.name
314+
result = payload.result
315+
316+
# Example: Filter sensitive data from results
317+
if isinstance(result, dict) and "sensitive_data" in result:
318+
result["sensitive_data"] = "[REDACTED]"
319+
return ToolPostInvokeResult(
320+
modified_payload=ToolPostInvokePayload(tool_name, result)
321+
)
322+
323+
# Example: Add audit metadata
324+
context.metadata["tool_executed"] = tool_name
325+
context.metadata["execution_time"] = time.time()
326+
327+
return ToolPostInvokeResult()
328+
254329
async def shutdown(self):
255330
"""Cleanup when plugin shuts down."""
256331
# Close connections, save state, etc.

mcpgateway/db.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,9 +1117,6 @@ class Gateway(Base):
11171117
# Header passthrough configuration
11181118
passthrough_headers: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) # Store list of strings as JSON array
11191119

1120-
# Header passthrough configuration
1121-
passthrough_headers: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True) # Store list of strings as JSON array
1122-
11231120
# Relationship with local tools this gateway provides
11241121
tools: Mapped[List["Tool"]] = relationship(back_populates="gateway", foreign_keys="Tool.gateway_id", cascade="all, delete-orphan")
11251122

mcpgateway/plugins/framework/base.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
PromptPosthookResult,
2828
PromptPrehookPayload,
2929
PromptPrehookResult,
30+
ToolPostInvokePayload,
31+
ToolPostInvokeResult,
32+
ToolPreInvokePayload,
33+
ToolPreInvokeResult,
3034
)
3135

3236

@@ -166,6 +170,46 @@ async def prompt_post_fetch(self, payload: PromptPosthookPayload, context: Plugi
166170
"""
167171
)
168172

173+
async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult:
174+
"""Plugin hook run before a tool is invoked.
175+
176+
Args:
177+
payload: The tool payload to be analyzed.
178+
context: Contextual information about the hook call.
179+
180+
Returns:
181+
ToolPreInvokeResult with processing status and modified payload.
182+
183+
Examples:
184+
>>> from mcpgateway.plugins.framework.plugin_types import ToolPreInvokePayload, PluginContext, GlobalContext
185+
>>> payload = ToolPreInvokePayload("calculator", {"operation": "add", "a": 5, "b": 3})
186+
>>> context = PluginContext(GlobalContext(request_id="123"))
187+
>>> # In async context:
188+
>>> # result = await plugin.tool_pre_invoke(payload, context)
189+
"""
190+
# Default pass-through implementation
191+
return ToolPreInvokeResult(continue_processing=True, modified_payload=payload)
192+
193+
async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult:
194+
"""Plugin hook run after a tool is invoked.
195+
196+
Args:
197+
payload: The tool result payload to be analyzed.
198+
context: Contextual information about the hook call.
199+
200+
Returns:
201+
ToolPostInvokeResult with processing status and modified result.
202+
203+
Examples:
204+
>>> from mcpgateway.plugins.framework.plugin_types import ToolPostInvokePayload, PluginContext, GlobalContext
205+
>>> payload = ToolPostInvokePayload("calculator", {"result": 8, "status": "success"})
206+
>>> context = PluginContext(GlobalContext(request_id="123"))
207+
>>> # In async context:
208+
>>> # result = await plugin.tool_post_invoke(payload, context)
209+
"""
210+
# Default pass-through implementation
211+
return ToolPostInvokeResult(continue_processing=True, modified_payload=payload)
212+
169213
async def shutdown(self) -> None:
170214
"""Plugin cleanup code."""
171215

mcpgateway/plugins/framework/manager.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@
4646
PromptPosthookResult,
4747
PromptPrehookPayload,
4848
PromptPrehookResult,
49+
ToolPostInvokePayload,
50+
ToolPostInvokeResult,
51+
ToolPreInvokePayload,
52+
ToolPreInvokeResult,
4953
)
5054
from mcpgateway.plugins.framework.registry import PluginInstanceRegistry
51-
from mcpgateway.plugins.framework.utils import post_prompt_matches, pre_prompt_matches
55+
from mcpgateway.plugins.framework.utils import post_prompt_matches, post_tool_matches, pre_prompt_matches, pre_tool_matches
5256

5357
logger = logging.getLogger(__name__)
5458

@@ -303,6 +307,54 @@ async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, c
303307
return await plugin.plugin.prompt_post_fetch(payload, context)
304308

305309

310+
async def pre_tool_invoke(plugin: PluginRef, payload: ToolPreInvokePayload, context: PluginContext) -> ToolPreInvokeResult:
311+
"""Call plugin's tool pre-invoke hook.
312+
313+
Args:
314+
plugin: The plugin to execute.
315+
payload: The tool payload to be analyzed.
316+
context: Contextual information about the hook call.
317+
318+
Returns:
319+
The result of the plugin execution.
320+
321+
Examples:
322+
>>> from mcpgateway.plugins.framework.base import Plugin, PluginRef
323+
>>> from mcpgateway.plugins.framework.plugin_types import ToolPreInvokePayload, PluginContext, GlobalContext
324+
>>> # Assuming you have a plugin instance:
325+
>>> # plugin_ref = PluginRef(my_plugin)
326+
>>> payload = ToolPreInvokePayload(name="calculator", args={"operation": "add", "a": 5, "b": 3})
327+
>>> context = PluginContext(GlobalContext(request_id="123"))
328+
>>> # In async context:
329+
>>> # result = await pre_tool_invoke(plugin_ref, payload, context)
330+
"""
331+
return await plugin.plugin.tool_pre_invoke(payload, context)
332+
333+
334+
async def post_tool_invoke(plugin: PluginRef, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult:
335+
"""Call plugin's tool post-invoke hook.
336+
337+
Args:
338+
plugin: The plugin to execute.
339+
payload: The tool result payload to be analyzed.
340+
context: Contextual information about the hook call.
341+
342+
Returns:
343+
The result of the plugin execution.
344+
345+
Examples:
346+
>>> from mcpgateway.plugins.framework.base import Plugin, PluginRef
347+
>>> from mcpgateway.plugins.framework.plugin_types import ToolPostInvokePayload, PluginContext, GlobalContext
348+
>>> # Assuming you have a plugin instance:
349+
>>> # plugin_ref = PluginRef(my_plugin)
350+
>>> payload = ToolPostInvokePayload(name="calculator", result={"result": 8, "status": "success"})
351+
>>> context = PluginContext(GlobalContext(request_id="123"))
352+
>>> # In async context:
353+
>>> # result = await post_tool_invoke(plugin_ref, payload, context)
354+
"""
355+
return await plugin.plugin.tool_post_invoke(payload, context)
356+
357+
306358
class PluginManager:
307359
"""Plugin manager for managing the plugin lifecycle.
308360
@@ -343,6 +395,8 @@ class PluginManager:
343395
_config: Config | None = None
344396
_pre_prompt_executor: PluginExecutor[PromptPrehookPayload] = PluginExecutor[PromptPrehookPayload]()
345397
_post_prompt_executor: PluginExecutor[PromptPosthookPayload] = PluginExecutor[PromptPosthookPayload]()
398+
_pre_tool_executor: PluginExecutor[ToolPreInvokePayload] = PluginExecutor[ToolPreInvokePayload]()
399+
_post_tool_executor: PluginExecutor[ToolPostInvokePayload] = PluginExecutor[ToolPostInvokePayload]()
346400

347401
# Context cleanup tracking
348402
_context_store: Dict[str, Tuple[PluginContextTable, float]] = {}
@@ -369,6 +423,8 @@ def __init__(self, config: str = "", timeout: int = DEFAULT_PLUGIN_TIMEOUT):
369423
# Update executor timeouts
370424
self._pre_prompt_executor.timeout = timeout
371425
self._post_prompt_executor.timeout = timeout
426+
self._pre_tool_executor.timeout = timeout
427+
self._post_tool_executor.timeout = timeout
372428

373429
# Initialize context tracking if not already done
374430
if not hasattr(self, "_context_store"):
@@ -614,3 +670,112 @@ async def prompt_post_fetch(
614670
del self._context_store[global_context.request_id]
615671

616672
return result
673+
674+
async def tool_pre_invoke(
675+
self,
676+
payload: ToolPreInvokePayload,
677+
global_context: GlobalContext,
678+
local_contexts: Optional[PluginContextTable] = None,
679+
) -> tuple[ToolPreInvokeResult, PluginContextTable | None]:
680+
"""Execute pre-invoke hooks before a tool is invoked.
681+
682+
Args:
683+
payload: The tool payload containing name and arguments.
684+
global_context: Shared context for all plugins with request metadata.
685+
local_contexts: Optional existing contexts from previous executions.
686+
687+
Returns:
688+
A tuple containing:
689+
- ToolPreInvokeResult with processing status and modified payload
690+
- PluginContextTable with updated contexts for post-invoke hook
691+
692+
Raises:
693+
PayloadSizeError: If payload exceeds size limits.
694+
695+
Examples:
696+
>>> manager = PluginManager("plugins/config.yaml")
697+
>>> # In async context:
698+
>>> # await manager.initialize()
699+
>>>
700+
>>> from mcpgateway.plugins.framework.plugin_types import ToolPreInvokePayload, GlobalContext
701+
>>> payload = ToolPreInvokePayload(
702+
... name="calculator",
703+
... args={"operation": "add", "a": 5, "b": 3}
704+
... )
705+
>>> context = GlobalContext(
706+
... request_id="req-123",
707+
... user="[email protected]"
708+
... )
709+
>>>
710+
>>> # In async context:
711+
>>> # result, contexts = await manager.tool_pre_invoke(payload, context)
712+
>>> # if result.continue_processing:
713+
>>> # # Proceed with tool invocation
714+
>>> # modified_payload = result.modified_payload or payload
715+
"""
716+
# Cleanup old contexts periodically
717+
await self._cleanup_old_contexts()
718+
719+
# Get plugins configured for this hook
720+
plugins = self._registry.get_plugins_for_hook(HookType.TOOL_PRE_INVOKE)
721+
722+
# Execute plugins
723+
result = await self._pre_tool_executor.execute(plugins, payload, global_context, pre_tool_invoke, pre_tool_matches, local_contexts)
724+
725+
# Store contexts for potential reuse
726+
if result[1]:
727+
self._context_store[global_context.request_id] = (result[1], time.time())
728+
729+
return result
730+
731+
async def tool_post_invoke(
732+
self, payload: ToolPostInvokePayload, global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None
733+
) -> tuple[ToolPostInvokeResult, PluginContextTable | None]:
734+
"""Execute post-invoke hooks after a tool is invoked.
735+
736+
Args:
737+
payload: The tool result payload containing invocation results.
738+
global_context: Shared context for all plugins with request metadata.
739+
local_contexts: Optional contexts from pre-invoke hook execution.
740+
741+
Returns:
742+
A tuple containing:
743+
- ToolPostInvokeResult with processing status and modified result
744+
- PluginContextTable with final contexts
745+
746+
Raises:
747+
PayloadSizeError: If payload exceeds size limits.
748+
749+
Examples:
750+
>>> # Continuing from tool_pre_invoke example
751+
>>> from mcpgateway.plugins.framework.plugin_types import ToolPostInvokePayload, GlobalContext
752+
>>>
753+
>>> post_payload = ToolPostInvokePayload(
754+
... name="calculator",
755+
... result={"result": 8, "status": "success"}
756+
... )
757+
>>>
758+
>>> manager = PluginManager("plugins/config.yaml")
759+
>>> context = GlobalContext(request_id="req-123")
760+
>>>
761+
>>> # In async context:
762+
>>> # result, _ = await manager.tool_post_invoke(
763+
>>> # post_payload,
764+
>>> # context,
765+
>>> # contexts # From pre_invoke
766+
>>> # )
767+
>>> # if result.modified_payload:
768+
>>> # # Use modified result
769+
>>> # final_result = result.modified_payload.result
770+
"""
771+
# Get plugins configured for this hook
772+
plugins = self._registry.get_plugins_for_hook(HookType.TOOL_POST_INVOKE)
773+
774+
# Execute plugins
775+
result = await self._post_tool_executor.execute(plugins, payload, global_context, post_tool_invoke, post_tool_matches, local_contexts)
776+
777+
# Clean up stored context after post-invoke
778+
if global_context.request_id in self._context_store:
779+
del self._context_store[global_context.request_id]
780+
781+
return result

0 commit comments

Comments
 (0)