Skip to content

Commit 7c7349c

Browse files
committed
Doctest for plugins
Signed-off-by: Mihai Criveti <[email protected]>
1 parent a4742bd commit 7c7349c

File tree

7 files changed

+376
-11
lines changed

7 files changed

+376
-11
lines changed

mcpgateway/plugins/framework/base.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,52 @@
3131

3232

3333
class Plugin:
34-
"""Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server."""
34+
"""Base plugin object for pre/post processing of inputs and outputs at various locations throughout the server.
35+
36+
Examples:
37+
>>> from mcpgateway.plugins.framework.models import PluginConfig, HookType, PluginMode
38+
>>> config = PluginConfig(
39+
... name="test_plugin",
40+
... description="Test plugin",
41+
... author="test",
42+
... kind="mcpgateway.plugins.framework.base.Plugin",
43+
... version="1.0.0",
44+
... hooks=[HookType.PROMPT_PRE_FETCH],
45+
... tags=["test"],
46+
... mode=PluginMode.ENFORCE,
47+
... priority=50
48+
... )
49+
>>> plugin = Plugin(config)
50+
>>> plugin.name
51+
'test_plugin'
52+
>>> plugin.priority
53+
50
54+
>>> plugin.mode
55+
<PluginMode.ENFORCE: 'enforce'>
56+
>>> HookType.PROMPT_PRE_FETCH in plugin.hooks
57+
True
58+
"""
3559

3660
def __init__(self, config: PluginConfig) -> None:
3761
"""Initialize a plugin with a configuration and context.
3862
3963
Args:
4064
config: The plugin configuration
65+
66+
Examples:
67+
>>> from mcpgateway.plugins.framework.models import PluginConfig, HookType
68+
>>> config = PluginConfig(
69+
... name="simple_plugin",
70+
... description="Simple test",
71+
... author="test",
72+
... kind="test.Plugin",
73+
... version="1.0.0",
74+
... hooks=[HookType.PROMPT_POST_FETCH],
75+
... tags=["simple"]
76+
... )
77+
>>> plugin = Plugin(config)
78+
>>> plugin._config.name
79+
'simple_plugin'
4180
"""
4281
self._config = config
4382

@@ -132,13 +171,58 @@ async def shutdown(self) -> None:
132171

133172

134173
class PluginRef:
135-
"""Plugin reference which contains a uuid."""
174+
"""Plugin reference which contains a uuid.
175+
176+
Examples:
177+
>>> from mcpgateway.plugins.framework.models import PluginConfig, HookType, PluginMode
178+
>>> config = PluginConfig(
179+
... name="ref_test",
180+
... description="Reference test",
181+
... author="test",
182+
... kind="test.Plugin",
183+
... version="1.0.0",
184+
... hooks=[HookType.PROMPT_PRE_FETCH],
185+
... tags=["ref", "test"],
186+
... mode=PluginMode.PERMISSIVE,
187+
... priority=100
188+
... )
189+
>>> plugin = Plugin(config)
190+
>>> ref = PluginRef(plugin)
191+
>>> ref.name
192+
'ref_test'
193+
>>> ref.priority
194+
100
195+
>>> ref.mode
196+
<PluginMode.PERMISSIVE: 'permissive'>
197+
>>> len(ref.uuid) # UUID is a 32-character hex string
198+
32
199+
>>> ref.tags
200+
['ref', 'test']
201+
"""
136202

137203
def __init__(self, plugin: Plugin):
138204
"""Initialize a plugin reference.
139205
140206
Args:
141207
plugin: The plugin to reference.
208+
209+
Examples:
210+
>>> from mcpgateway.plugins.framework.models import PluginConfig, HookType
211+
>>> config = PluginConfig(
212+
... name="plugin_ref",
213+
... description="Test",
214+
... author="test",
215+
... kind="test.Plugin",
216+
... version="1.0.0",
217+
... hooks=[HookType.PROMPT_POST_FETCH],
218+
... tags=[]
219+
... )
220+
>>> plugin = Plugin(config)
221+
>>> ref = PluginRef(plugin)
222+
>>> ref._plugin.name
223+
'plugin_ref'
224+
>>> isinstance(ref._uuid, uuid.UUID)
225+
True
142226
"""
143227
self._plugin = plugin
144228
self._uuid = uuid.uuid4()

mcpgateway/plugins/framework/loader/config.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,28 @@
2020

2121

2222
class ConfigLoader:
23-
"""A configuration loader."""
23+
"""A configuration loader.
24+
25+
Examples:
26+
>>> import tempfile
27+
>>> import os
28+
>>> from mcpgateway.plugins.framework.models import PluginSettings
29+
>>> # Create a temporary config file
30+
>>> with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
31+
... _ = f.write(\"\"\"
32+
... plugin_settings:
33+
... enable_plugin_api: true
34+
... plugin_timeout: 30
35+
... plugin_dirs: ['/path/to/plugins']
36+
... \"\"\")
37+
... temp_path = f.name
38+
>>> try:
39+
... config = ConfigLoader.load_config(temp_path, use_jinja=False)
40+
... config.plugin_settings.enable_plugin_api
41+
... finally:
42+
... os.unlink(temp_path)
43+
True
44+
"""
2445

2546
@staticmethod
2647
def load_config(config: str, use_jinja: bool = True) -> Config:
@@ -32,6 +53,24 @@ def load_config(config: str, use_jinja: bool = True) -> Config:
3253
3354
Returns:
3455
The plugin configuration object.
56+
57+
Examples:
58+
>>> import tempfile
59+
>>> import os
60+
>>> with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
61+
... _ = f.write(\"\"\"
62+
... plugin_settings:
63+
... plugin_timeout: 60
64+
... enable_plugin_api: false
65+
... plugin_dirs: []
66+
... \"\"\")
67+
... temp_path = f.name
68+
>>> try:
69+
... cfg = ConfigLoader.load_config(temp_path, use_jinja=False)
70+
... cfg.plugin_settings.plugin_timeout
71+
... finally:
72+
... os.unlink(temp_path)
73+
60
3574
"""
3675
with open(os.path.normpath(config), "r", encoding="utf-8") as file:
3776
template = file.read()

mcpgateway/plugins/framework/loader/plugin.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,24 @@
2121

2222

2323
class PluginLoader(object):
24-
"""A plugin loader object for loading and instantiating plugins."""
24+
"""A plugin loader object for loading and instantiating plugins.
25+
26+
Examples:
27+
>>> loader = PluginLoader()
28+
>>> isinstance(loader._plugin_types, dict)
29+
True
30+
>>> len(loader._plugin_types)
31+
0
32+
"""
2533

2634
def __init__(self) -> None:
27-
"""Initialize the plugin loader."""
35+
"""Initialize the plugin loader.
36+
37+
Examples:
38+
>>> loader = PluginLoader()
39+
>>> loader._plugin_types
40+
{}
41+
"""
2842
self._plugin_types: dict[str, Type[Plugin]] = {}
2943

3044
def __get_plugin_type(self, kind: str) -> Type[Plugin]:

mcpgateway/plugins/framework/models.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ class HookType(str, Enum):
2323
Attributes:
2424
prompt_pre_fetch: The prompt pre hook.
2525
prompt_post_fetch: The prompt post hook.
26+
27+
Examples:
28+
>>> HookType.PROMPT_PRE_FETCH
29+
<HookType.PROMPT_PRE_FETCH: 'prompt_pre_fetch'>
30+
>>> HookType.PROMPT_PRE_FETCH.value
31+
'prompt_pre_fetch'
32+
>>> HookType('prompt_post_fetch')
33+
<HookType.PROMPT_POST_FETCH: 'prompt_post_fetch'>
34+
>>> list(HookType)
35+
[<HookType.PROMPT_PRE_FETCH: 'prompt_pre_fetch'>, <HookType.PROMPT_POST_FETCH: 'prompt_post_fetch'>]
2636
"""
2737

2838
PROMPT_PRE_FETCH = "prompt_pre_fetch"
@@ -36,6 +46,16 @@ class PluginMode(str, Enum):
3646
enforce: enforces the plugin result.
3747
permissive: audits the result.
3848
disabled: plugin disabled.
49+
50+
Examples:
51+
>>> PluginMode.ENFORCE
52+
<PluginMode.ENFORCE: 'enforce'>
53+
>>> PluginMode.PERMISSIVE.value
54+
'permissive'
55+
>>> PluginMode('disabled')
56+
<PluginMode.DISABLED: 'disabled'>
57+
>>> 'enforce' in [m.value for m in PluginMode]
58+
True
3959
"""
4060

4161
ENFORCE = "enforce"
@@ -50,6 +70,18 @@ class ToolTemplate(BaseModel):
5070
tool_name (str): the name of the tool.
5171
fields (Optional[list[str]]): the tool fields that are affected.
5272
result (bool): analyze tool output if true.
73+
74+
Examples:
75+
>>> tool = ToolTemplate(tool_name="my_tool")
76+
>>> tool.tool_name
77+
'my_tool'
78+
>>> tool.result
79+
False
80+
>>> tool2 = ToolTemplate(tool_name="analyzer", fields=["input", "params"], result=True)
81+
>>> tool2.fields
82+
['input', 'params']
83+
>>> tool2.result
84+
True
5385
"""
5486

5587
tool_name: str
@@ -64,6 +96,16 @@ class PromptTemplate(BaseModel):
6496
prompt_name (str): the name of the prompt.
6597
fields (Optional[list[str]]): the prompt fields that are affected.
6698
result (bool): analyze tool output if true.
99+
100+
Examples:
101+
>>> prompt = PromptTemplate(prompt_name="greeting")
102+
>>> prompt.prompt_name
103+
'greeting'
104+
>>> prompt.result
105+
False
106+
>>> prompt2 = PromptTemplate(prompt_name="question", fields=["context"], result=True)
107+
>>> prompt2.fields
108+
['context']
67109
"""
68110

69111
prompt_name: str
@@ -81,6 +123,17 @@ class PluginCondition(BaseModel):
81123
prompts (Optional[set[str]]): set of prompt names.
82124
user_pattern (Optional[list[str]]): list of user patterns.
83125
content_types (Optional[list[str]]): list of content types.
126+
127+
Examples:
128+
>>> cond = PluginCondition(server_ids={"server1", "server2"})
129+
>>> "server1" in cond.server_ids
130+
True
131+
>>> cond2 = PluginCondition(tools={"tool1"}, prompts={"prompt1"})
132+
>>> cond2.tools
133+
{'tool1'}
134+
>>> cond3 = PluginCondition(user_patterns=["admin", "root"])
135+
>>> len(cond3.user_patterns)
136+
2
84137
"""
85138

86139
server_ids: Optional[set[str]] = None
@@ -166,6 +219,21 @@ class PluginViolation(BaseModel):
166219
code (str): a violation code.
167220
details: (dict[str, Any]): additional violation details.
168221
_plugin_name (str): the plugin name, private attribute set by the plugin manager.
222+
223+
Examples:
224+
>>> violation = PluginViolation(
225+
... reason="Invalid input",
226+
... description="The input contains prohibited content",
227+
... code="PROHIBITED_CONTENT",
228+
... details={"field": "message", "value": "test"}
229+
... )
230+
>>> violation.reason
231+
'Invalid input'
232+
>>> violation.code
233+
'PROHIBITED_CONTENT'
234+
>>> violation.plugin_name = "content_filter"
235+
>>> violation.plugin_name
236+
'content_filter'
169237
"""
170238

171239
reason: str

0 commit comments

Comments
 (0)