Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions plugins/external/opa/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ plugins:
extensions:
policy: "example"
policy_endpoint: "allow"
# policy_input_data_map:
# "context.git_context": "git_context"
# "payload.args.repo_path": "repo_path"
conditions:
# Apply to specific tools/servers
- server_ids: [] # Apply to all servers
Expand All @@ -55,12 +58,13 @@ The `applied_to` key in config.yaml, has been used to selectively apply policies
Here, using this, you can provide the `name` of the tool you want to apply policy on, you can also provide
context to the tool with the prefix `global` if it needs to check the context in global context provided.
The key `opa_policy_context` is used to get context for policies and you can have multiple contexts within this key using `git_context` key.
You can also provide policy within the `extensions` key where you can provide information to the plugin
related to which policy to run and what endpoint to call for that policy.
In the `config` key in `config.yaml` file OPAPlugin consists of the following things:

Under `extensions`, you can specify which policy to run and what endpoint to call for that policy. Optionally, an input data map can be specified to transform the input passed to the OPA policy. This works by mapping (transforming) the original input data onto a new representation. In the example above, the original input data `"input":{{"payload": {..., "args": {"repo_path": ..., ...}, "context": "git_context": {...}}, ...}}` is mapped to `"input":{"repo_path": ..., "git_context": {...}}`. Observe that the policy (rego file) must accept the input schema.

In the `config` key in `config.yaml` for the OPA plugin, the following attribute must be set to configure the OPA server endpoint:
`opa_base_url` : It is the base url on which opa server is running.

3. Now suppose i have a sample policy, in `example.rego` file that allows a tool invocation only when "IBM" key word is present in the repo_path. Add the sample policy file or policy rego file that you defined, in `plugins/external/opa/opaserver/rego`.
3. Now suppose you have a sample policy in `example.rego` file that allows a tool invocation only when "IBM" key word is present in the repo_path. Add the sample policy file or policy rego file that you defined, in `plugins/external/opa/opaserver/rego`.

3. Once you have your plugin defined in `config.yaml` and policy added in the rego file, run the following commands to build your OPA Plugin external MCP server using:
* `make build`: This will build a docker image named `opapluginfilter`
Expand Down
65 changes: 46 additions & 19 deletions plugins/external/opa/opapluginfilter/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from typing import Any

# Third-Party
from opapluginfilter.schema import BaseOPAInputKeys, OPAConfig, OPAInput
import requests

# First-Party
from mcpgateway.plugins.framework import (
Plugin,
PluginConfig,
PluginContext,
PluginViolation,
PromptPosthookPayload,
PromptPosthookResult,
PromptPrehookPayload,
Expand All @@ -28,13 +30,7 @@
ToolPreInvokePayload,
ToolPreInvokeResult,
)
from mcpgateway.plugins.framework.models import PluginConfig, PluginViolation
from mcpgateway.services.logging_service import LoggingService
from opapluginfilter.schema import (
BaseOPAInputKeys,
OPAConfig,
OPAInput
)

# Initialize logging service first
logging_service = LoggingService()
Expand All @@ -55,8 +51,29 @@ def __init__(self, config: PluginConfig):
self.opa_config = OPAConfig.model_validate(self._config.config)
self.opa_context_key = "opa_policy_context"

def _get_nested_value(self, data, key_string, default=None):
"""
Retrieves a value from a nested dictionary using a dot-notation string.

Args:
data (dict): The dictionary to search within.
key_string (str): The dot-notation string representing the path to the value.
default (any, optional): The value to return if the key path is not found.
Defaults to None.

Returns:
any: The value at the specified key path, or the default value if not found.
"""
keys = key_string.split(".")
current_data = data
for key in keys:
if isinstance(current_data, dict) and key in current_data:
current_data = current_data[key]
else:
return default # Key not found at this level
return current_data

def _evaluate_opa_policy(self, url: str, input: OPAInput) -> tuple[bool,Any]:
def _evaluate_opa_policy(self, url: str, input: OPAInput, policy_input_data_map: dict) -> tuple[bool, Any]:
"""Function to evaluate OPA policy. Makes a request to opa server with url and input.

Args:
Expand All @@ -70,16 +87,24 @@ def _evaluate_opa_policy(self, url: str, input: OPAInput) -> tuple[bool,Any]:

"""

payload = input.model_dump()
def _key(k: str, m: str) -> str:
return f"{k}.{m}" if k.split(".")[0] == "context" else k

payload = {"input": {m: self._get_nested_value(input.model_dump()["input"], _key(k, m)) for k, m in policy_input_data_map.items()}} if policy_input_data_map else input.model_dump()
logger.info(f"OPA url {url}, OPA payload {payload}")
rsp = requests.post(url, json=payload)
logger.info(f"OPA connection response '{rsp}'")
if rsp.status_code == 200:
json_response = rsp.json()
decision = json_response.get("result",None)
decision = json_response.get("result", None)
logger.info(f"OPA server response '{json_response}'")
if isinstance(decision,bool):
if isinstance(decision, bool):
logger.debug(f"OPA decision {decision}")
return decision, json_response
elif isinstance(decision, dict) and "allow" in decision:
allow = decision["allow"]
logger.debug(f"OPA decision {allow}")
return allow, json_response
else:
logger.debug(f"OPA sent a none response {json_response}")
else:
Expand Down Expand Up @@ -128,34 +153,36 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo
if not payload.args:
return ToolPreInvokeResult()


tool_context = []
policy_context = {}
tool_policy = None
tool_policy_endpoint = None
tool_policy_input_data_map = {}
# Get the tool for which policy needs to be applied
policy_apply_config = self._config.applied_to
if policy_apply_config and policy_apply_config.tools:
for tool in policy_apply_config.tools:
tool_name = tool.tool_name
if payload.name == tool_name:
if tool.context:
tool_context = [ctx.rsplit('.', 1)[-1] for ctx in tool.context]
tool_context = [ctx.rsplit(".", 1)[-1] for ctx in tool.context]
if self.opa_context_key in context.global_context.state:
policy_context = {k : context.global_context.state[self.opa_context_key][k] for k in tool_context}
policy_context = {k: context.global_context.state[self.opa_context_key][k] for k in tool_context}
if tool.extensions:
tool_policy = tool.extensions.get("policy",None)
tool_policy_endpoint = tool.extensions.get("policy_endpoint",None)
tool_policy = tool.extensions.get("policy", None)
tool_policy_endpoint = tool.extensions.get("policy_endpoint", None)
tool_policy_input_data_map = tool.extensions.get("policy_input_data_map", {})

opa_input = BaseOPAInputKeys(kind="tools/call", user = "none", payload=payload.model_dump(), context=policy_context, request_ip = "none", headers = {}, response = {})
opa_server_url = "{opa_url}{policy}/{policy_endpoint}".format(opa_url = self.opa_config.opa_base_url, policy=tool_policy, policy_endpoint=tool_policy_endpoint)
decision, decision_context = self._evaluate_opa_policy(url=opa_server_url,input=OPAInput(input=opa_input))
opa_input = BaseOPAInputKeys(kind="tools/call", user="none", payload=payload.model_dump(), context=policy_context, request_ip="none", headers={}, response={})
opa_server_url = "{opa_url}{policy}/{policy_endpoint}".format(opa_url=self.opa_config.opa_base_url, policy=tool_policy, policy_endpoint=tool_policy_endpoint)
decision, decision_context = self._evaluate_opa_policy(url=opa_server_url, input=OPAInput(input=opa_input), policy_input_data_map=tool_policy_input_data_map)
if not decision:
violation = PluginViolation(
reason="tool invocation not allowed",
description="OPA policy denied for tool preinvocation",
code="deny",
details=decision_context,)
details=decision_context,
)
return ToolPreInvokeResult(modified_payload=payload, violation=violation, continue_processing=False)
return ToolPreInvokeResult(continue_processing=True)

Expand Down
18 changes: 11 additions & 7 deletions plugins/external/opa/opapluginfilter/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
"""

# Standard
from typing import Optional, Any
from typing import Any, Optional

# Third-Party
from pydantic import BaseModel


class BaseOPAInputKeys(BaseModel):
"""BaseOPAInputKeys

Expand All @@ -34,11 +35,12 @@ class BaseOPAInputKeys(BaseModel):
'{"opa_policy_context" : {"context1" : "value1"}}'

"""
kind : Optional[str] = None
user : Optional[str] = None
request_ip : Optional[str] = None
headers : Optional[dict[str, str]] = None
response : Optional[dict[str, str]] = None

kind: Optional[str] = None
user: Optional[str] = None
request_ip: Optional[str] = None
headers: Optional[dict[str, str]] = None
response: Optional[dict[str, str]] = None
payload: dict[str, Any]
context: Optional[dict[str, Any]] = None

Expand All @@ -57,7 +59,9 @@ class OPAInput(BaseModel):
'{"opa_policy_context" : {"context1" : "value1"}}'

"""
input : BaseOPAInputKeys

input: BaseOPAInputKeys


class OPAConfig(BaseModel):
"""Configuration for the OPA plugin."""
Expand Down
11 changes: 5 additions & 6 deletions plugins/external/opa/tests/server/test_opa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,17 @@


# Standard
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
import threading

# Third-Party
from http.server import BaseHTTPRequestHandler, HTTPServer


# This class mocks up the post request for OPA server to evaluate policies.
class MockOPAHandler(BaseHTTPRequestHandler):
def do_POST(self):
if self.path == "/v1/data/example/allow":
content_length = int(self.headers.get('Content-Length', 0))
post_body = self.rfile.read(content_length).decode('utf-8')
content_length = int(self.headers.get("Content-Length", 0))
post_body = self.rfile.read(content_length).decode("utf-8")
try:
data = json.loads(post_body)
if "IBM" in data["input"]["payload"]["args"]["repo_path"]:
Expand All @@ -43,8 +41,9 @@ def do_POST(self):
self.wfile.write(b"Invalid JSON")
return


# This creates a mock up server for OPA at port 8181
def run_mock_opa():
server = HTTPServer(('localhost', 8181), MockOPAHandler)
server = HTTPServer(("localhost", 8181), MockOPAHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()
return server
12 changes: 7 additions & 5 deletions plugins/external/opa/tests/test_all.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
# -*- coding: utf-8 -*-
"""Tests for registered plugins."""

# Third-Party
# Standard
import asyncio

# Third-Party
import pytest

# First-Party
from mcpgateway.models import Message, PromptResult, Role, TextContent
from mcpgateway.models import Message, Role, TextContent
from mcpgateway.plugins.framework import (
PluginManager,
GlobalContext,
PromptPrehookPayload,
PluginManager,
PromptPosthookPayload,
PromptPrehookPayload,
PromptResult,
ToolPreInvokePayload,
ToolPostInvokePayload,
ToolPreInvokePayload,
)


Expand Down
34 changes: 7 additions & 27 deletions plugins/external/opa/tests/test_opapluginfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,23 @@


# Third-Party
from opapluginfilter.plugin import OPAPluginFilter
import pytest

# First-Party
from opapluginfilter.plugin import OPAPluginFilter
from mcpgateway.plugins.framework import (
PluginConfig,
PluginContext,
ToolPreInvokePayload,
GlobalContext
)
from mcpgateway.plugins.framework.models import AppliedTo, ToolTemplate
from mcpgateway.plugins.framework import GlobalContext, PluginConfig, PluginContext, ToolPreInvokePayload

# Local
from tests.server.opa_server import run_mock_opa


@pytest.mark.asyncio
# Test for when opaplugin is not applied to a tool
async def test_benign_opapluginfilter():
"""Test plugin prompt prefetch hook."""
config = PluginConfig(
name="test",
kind="opapluginfilter.OPAPluginFilter",
hooks=["tool_pre_invoke"],
config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}
)
config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"})
mock_server = run_mock_opa()


plugin = OPAPluginFilter(config)

# Test your plugin logic
Expand All @@ -52,12 +41,7 @@ async def test_benign_opapluginfilter():
# Test for when opaplugin is not applied to a tool
async def test_malign_opapluginfilter():
"""Test plugin prompt prefetch hook."""
config = PluginConfig(
name="test",
kind="opapluginfilter.OPAPluginFilter",
hooks=["tool_pre_invoke"],
config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}
)
config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"})
mock_server = run_mock_opa()
plugin = OPAPluginFilter(config)

Expand All @@ -68,16 +52,12 @@ async def test_malign_opapluginfilter():
mock_server.shutdown()
assert not result.continue_processing and result.violation.code == "deny"


@pytest.mark.asyncio
# Test for opa plugin not applied to any of the tools
async def test_applied_to_opaplugin():
"""Test plugin prompt prefetch hook."""
config = PluginConfig(
name="test",
kind="opapluginfilter.OPAPluginFilter",
hooks=["tool_pre_invoke"],
config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}
)
config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"})
mock_server = run_mock_opa()
plugin = OPAPluginFilter(config)

Expand Down
Loading