Skip to content

Commit df9b703

Browse files
araujofmonshri
authored andcommitted
feat: add opa policy input data mapping support (IBM#1102)
* feat: add opa policy input data mapping support Signed-off-by: Frederico Araujo <[email protected]> * chore: drop debugging print statement Signed-off-by: Frederico Araujo <[email protected]> --------- Signed-off-by: Frederico Araujo <[email protected]>
1 parent 02ea00b commit df9b703

File tree

6 files changed

+84
-68
lines changed

6 files changed

+84
-68
lines changed

plugins/external/opa/README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ plugins:
4343
extensions:
4444
policy: "example"
4545
policy_endpoint: "allow"
46+
# policy_input_data_map:
47+
# "context.git_context": "git_context"
48+
# "payload.args.repo_path": "repo_path"
4649
conditions:
4750
# Apply to specific tools/servers
4851
- server_ids: [] # Apply to all servers
@@ -55,12 +58,13 @@ The `applied_to` key in config.yaml, has been used to selectively apply policies
5558
Here, using this, you can provide the `name` of the tool you want to apply policy on, you can also provide
5659
context to the tool with the prefix `global` if it needs to check the context in global context provided.
5760
The key `opa_policy_context` is used to get context for policies and you can have multiple contexts within this key using `git_context` key.
58-
You can also provide policy within the `extensions` key where you can provide information to the plugin
59-
related to which policy to run and what endpoint to call for that policy.
60-
In the `config` key in `config.yaml` file OPAPlugin consists of the following things:
61+
62+
Under `extensions`, you can specify which policy to run and what endpoint to call for that policy. Optionally, an input data map can be specified to transform the input passed to the OPA policy. This works by mapping (transforming) the original input data onto a new representation. In the example above, the original input data `"input":{{"payload": {..., "args": {"repo_path": ..., ...}, "context": "git_context": {...}}, ...}}` is mapped to `"input":{"repo_path": ..., "git_context": {...}}`. Observe that the policy (rego file) must accept the input schema.
63+
64+
In the `config` key in `config.yaml` for the OPA plugin, the following attribute must be set to configure the OPA server endpoint:
6165
`opa_base_url` : It is the base url on which opa server is running.
6266

63-
3. Now suppose i have a sample policy, in `example.rego` file that allows a tool invocation only when "IBM" key word is present in the repo_path. Add the sample policy file or policy rego file that you defined, in `plugins/external/opa/opaserver/rego`.
67+
3. Now suppose you have a sample policy in `example.rego` file that allows a tool invocation only when "IBM" key word is present in the repo_path. Add the sample policy file or policy rego file that you defined, in `plugins/external/opa/opaserver/rego`.
6468

6569
3. Once you have your plugin defined in `config.yaml` and policy added in the rego file, run the following commands to build your OPA Plugin external MCP server using:
6670
* `make build`: This will build a docker image named `opapluginfilter`

plugins/external/opa/opapluginfilter/plugin.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
from typing import Any
1313

1414
# Third-Party
15+
from opapluginfilter.schema import BaseOPAInputKeys, OPAConfig, OPAInput
1516
import requests
1617

1718
# First-Party
1819
from mcpgateway.plugins.framework import (
1920
Plugin,
2021
PluginConfig,
2122
PluginContext,
23+
PluginViolation,
2224
PromptPosthookPayload,
2325
PromptPosthookResult,
2426
PromptPrehookPayload,
@@ -28,13 +30,7 @@
2830
ToolPreInvokePayload,
2931
ToolPreInvokeResult,
3032
)
31-
from mcpgateway.plugins.framework.models import PluginConfig, PluginViolation
3233
from mcpgateway.services.logging_service import LoggingService
33-
from opapluginfilter.schema import (
34-
BaseOPAInputKeys,
35-
OPAConfig,
36-
OPAInput
37-
)
3834

3935
# Initialize logging service first
4036
logging_service = LoggingService()
@@ -55,8 +51,29 @@ def __init__(self, config: PluginConfig):
5551
self.opa_config = OPAConfig.model_validate(self._config.config)
5652
self.opa_context_key = "opa_policy_context"
5753

54+
def _get_nested_value(self, data, key_string, default=None):
55+
"""
56+
Retrieves a value from a nested dictionary using a dot-notation string.
57+
58+
Args:
59+
data (dict): The dictionary to search within.
60+
key_string (str): The dot-notation string representing the path to the value.
61+
default (any, optional): The value to return if the key path is not found.
62+
Defaults to None.
63+
64+
Returns:
65+
any: The value at the specified key path, or the default value if not found.
66+
"""
67+
keys = key_string.split(".")
68+
current_data = data
69+
for key in keys:
70+
if isinstance(current_data, dict) and key in current_data:
71+
current_data = current_data[key]
72+
else:
73+
return default # Key not found at this level
74+
return current_data
5875

59-
def _evaluate_opa_policy(self, url: str, input: OPAInput) -> tuple[bool,Any]:
76+
def _evaluate_opa_policy(self, url: str, input: OPAInput, policy_input_data_map: dict) -> tuple[bool, Any]:
6077
"""Function to evaluate OPA policy. Makes a request to opa server with url and input.
6178
6279
Args:
@@ -70,16 +87,24 @@ def _evaluate_opa_policy(self, url: str, input: OPAInput) -> tuple[bool,Any]:
7087
7188
"""
7289

73-
payload = input.model_dump()
90+
def _key(k: str, m: str) -> str:
91+
return f"{k}.{m}" if k.split(".")[0] == "context" else k
92+
93+
payload = {"input": {m: self._get_nested_value(input.model_dump()["input"], _key(k, m)) for k, m in policy_input_data_map.items()}} if policy_input_data_map else input.model_dump()
7494
logger.info(f"OPA url {url}, OPA payload {payload}")
7595
rsp = requests.post(url, json=payload)
7696
logger.info(f"OPA connection response '{rsp}'")
7797
if rsp.status_code == 200:
7898
json_response = rsp.json()
79-
decision = json_response.get("result",None)
99+
decision = json_response.get("result", None)
80100
logger.info(f"OPA server response '{json_response}'")
81-
if isinstance(decision,bool):
101+
if isinstance(decision, bool):
102+
logger.debug(f"OPA decision {decision}")
82103
return decision, json_response
104+
elif isinstance(decision, dict) and "allow" in decision:
105+
allow = decision["allow"]
106+
logger.debug(f"OPA decision {allow}")
107+
return allow, json_response
83108
else:
84109
logger.debug(f"OPA sent a none response {json_response}")
85110
else:
@@ -128,34 +153,36 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo
128153
if not payload.args:
129154
return ToolPreInvokeResult()
130155

131-
132156
tool_context = []
133157
policy_context = {}
134158
tool_policy = None
135159
tool_policy_endpoint = None
160+
tool_policy_input_data_map = {}
136161
# Get the tool for which policy needs to be applied
137162
policy_apply_config = self._config.applied_to
138163
if policy_apply_config and policy_apply_config.tools:
139164
for tool in policy_apply_config.tools:
140165
tool_name = tool.tool_name
141166
if payload.name == tool_name:
142167
if tool.context:
143-
tool_context = [ctx.rsplit('.', 1)[-1] for ctx in tool.context]
168+
tool_context = [ctx.rsplit(".", 1)[-1] for ctx in tool.context]
144169
if self.opa_context_key in context.global_context.state:
145-
policy_context = {k : context.global_context.state[self.opa_context_key][k] for k in tool_context}
170+
policy_context = {k: context.global_context.state[self.opa_context_key][k] for k in tool_context}
146171
if tool.extensions:
147-
tool_policy = tool.extensions.get("policy",None)
148-
tool_policy_endpoint = tool.extensions.get("policy_endpoint",None)
172+
tool_policy = tool.extensions.get("policy", None)
173+
tool_policy_endpoint = tool.extensions.get("policy_endpoint", None)
174+
tool_policy_input_data_map = tool.extensions.get("policy_input_data_map", {})
149175

150-
opa_input = BaseOPAInputKeys(kind="tools/call", user = "none", payload=payload.model_dump(), context=policy_context, request_ip = "none", headers = {}, response = {})
151-
opa_server_url = "{opa_url}{policy}/{policy_endpoint}".format(opa_url = self.opa_config.opa_base_url, policy=tool_policy, policy_endpoint=tool_policy_endpoint)
152-
decision, decision_context = self._evaluate_opa_policy(url=opa_server_url,input=OPAInput(input=opa_input))
176+
opa_input = BaseOPAInputKeys(kind="tools/call", user="none", payload=payload.model_dump(), context=policy_context, request_ip="none", headers={}, response={})
177+
opa_server_url = "{opa_url}{policy}/{policy_endpoint}".format(opa_url=self.opa_config.opa_base_url, policy=tool_policy, policy_endpoint=tool_policy_endpoint)
178+
decision, decision_context = self._evaluate_opa_policy(url=opa_server_url, input=OPAInput(input=opa_input), policy_input_data_map=tool_policy_input_data_map)
153179
if not decision:
154180
violation = PluginViolation(
155181
reason="tool invocation not allowed",
156182
description="OPA policy denied for tool preinvocation",
157183
code="deny",
158-
details=decision_context,)
184+
details=decision_context,
185+
)
159186
return ToolPreInvokeResult(modified_payload=payload, violation=violation, continue_processing=False)
160187
return ToolPreInvokeResult(continue_processing=True)
161188

plugins/external/opa/opapluginfilter/schema.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
"""
1010

1111
# Standard
12-
from typing import Optional, Any
12+
from typing import Any, Optional
1313

1414
# Third-Party
1515
from pydantic import BaseModel
1616

17+
1718
class BaseOPAInputKeys(BaseModel):
1819
"""BaseOPAInputKeys
1920
@@ -34,11 +35,12 @@ class BaseOPAInputKeys(BaseModel):
3435
'{"opa_policy_context" : {"context1" : "value1"}}'
3536
3637
"""
37-
kind : Optional[str] = None
38-
user : Optional[str] = None
39-
request_ip : Optional[str] = None
40-
headers : Optional[dict[str, str]] = None
41-
response : Optional[dict[str, str]] = None
38+
39+
kind: Optional[str] = None
40+
user: Optional[str] = None
41+
request_ip: Optional[str] = None
42+
headers: Optional[dict[str, str]] = None
43+
response: Optional[dict[str, str]] = None
4244
payload: dict[str, Any]
4345
context: Optional[dict[str, Any]] = None
4446

@@ -57,7 +59,9 @@ class OPAInput(BaseModel):
5759
'{"opa_policy_context" : {"context1" : "value1"}}'
5860
5961
"""
60-
input : BaseOPAInputKeys
62+
63+
input: BaseOPAInputKeys
64+
6165

6266
class OPAConfig(BaseModel):
6367
"""Configuration for the OPA plugin."""

plugins/external/opa/tests/server/test_opa_server.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,17 @@
1010

1111

1212
# Standard
13+
from http.server import BaseHTTPRequestHandler, HTTPServer
1314
import json
1415
import threading
1516

16-
# Third-Party
17-
from http.server import BaseHTTPRequestHandler, HTTPServer
18-
1917

2018
# This class mocks up the post request for OPA server to evaluate policies.
2119
class MockOPAHandler(BaseHTTPRequestHandler):
2220
def do_POST(self):
2321
if self.path == "/v1/data/example/allow":
24-
content_length = int(self.headers.get('Content-Length', 0))
25-
post_body = self.rfile.read(content_length).decode('utf-8')
22+
content_length = int(self.headers.get("Content-Length", 0))
23+
post_body = self.rfile.read(content_length).decode("utf-8")
2624
try:
2725
data = json.loads(post_body)
2826
if "IBM" in data["input"]["payload"]["args"]["repo_path"]:
@@ -43,8 +41,9 @@ def do_POST(self):
4341
self.wfile.write(b"Invalid JSON")
4442
return
4543

44+
4645
# This creates a mock up server for OPA at port 8181
4746
def run_mock_opa():
48-
server = HTTPServer(('localhost', 8181), MockOPAHandler)
47+
server = HTTPServer(("localhost", 8181), MockOPAHandler)
4948
threading.Thread(target=server.serve_forever, daemon=True).start()
5049
return server

plugins/external/opa/tests/test_all.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
# -*- coding: utf-8 -*-
22
"""Tests for registered plugins."""
33

4-
# Third-Party
4+
# Standard
55
import asyncio
6+
7+
# Third-Party
68
import pytest
79

810
# First-Party
9-
from mcpgateway.models import Message, PromptResult, Role, TextContent
11+
from mcpgateway.models import Message, Role, TextContent
1012
from mcpgateway.plugins.framework import (
11-
PluginManager,
1213
GlobalContext,
13-
PromptPrehookPayload,
14+
PluginManager,
1415
PromptPosthookPayload,
16+
PromptPrehookPayload,
1517
PromptResult,
16-
ToolPreInvokePayload,
1718
ToolPostInvokePayload,
19+
ToolPreInvokePayload,
1820
)
1921

2022

plugins/external/opa/tests/test_opapluginfilter.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,23 @@
1010

1111

1212
# Third-Party
13+
from opapluginfilter.plugin import OPAPluginFilter
1314
import pytest
1415

1516
# First-Party
16-
from opapluginfilter.plugin import OPAPluginFilter
17-
from mcpgateway.plugins.framework import (
18-
PluginConfig,
19-
PluginContext,
20-
ToolPreInvokePayload,
21-
GlobalContext
22-
)
23-
from mcpgateway.plugins.framework.models import AppliedTo, ToolTemplate
17+
from mcpgateway.plugins.framework import GlobalContext, PluginConfig, PluginContext, ToolPreInvokePayload
2418

19+
# Local
2520
from tests.server.opa_server import run_mock_opa
2621

2722

2823
@pytest.mark.asyncio
2924
# Test for when opaplugin is not applied to a tool
3025
async def test_benign_opapluginfilter():
3126
"""Test plugin prompt prefetch hook."""
32-
config = PluginConfig(
33-
name="test",
34-
kind="opapluginfilter.OPAPluginFilter",
35-
hooks=["tool_pre_invoke"],
36-
config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}
37-
)
27+
config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"})
3828
mock_server = run_mock_opa()
3929

40-
4130
plugin = OPAPluginFilter(config)
4231

4332
# Test your plugin logic
@@ -52,12 +41,7 @@ async def test_benign_opapluginfilter():
5241
# Test for when opaplugin is not applied to a tool
5342
async def test_malign_opapluginfilter():
5443
"""Test plugin prompt prefetch hook."""
55-
config = PluginConfig(
56-
name="test",
57-
kind="opapluginfilter.OPAPluginFilter",
58-
hooks=["tool_pre_invoke"],
59-
config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}
60-
)
44+
config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"})
6145
mock_server = run_mock_opa()
6246
plugin = OPAPluginFilter(config)
6347

@@ -68,16 +52,12 @@ async def test_malign_opapluginfilter():
6852
mock_server.shutdown()
6953
assert not result.continue_processing and result.violation.code == "deny"
7054

55+
7156
@pytest.mark.asyncio
7257
# Test for opa plugin not applied to any of the tools
7358
async def test_applied_to_opaplugin():
7459
"""Test plugin prompt prefetch hook."""
75-
config = PluginConfig(
76-
name="test",
77-
kind="opapluginfilter.OPAPluginFilter",
78-
hooks=["tool_pre_invoke"],
79-
config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"}
80-
)
60+
config = PluginConfig(name="test", kind="opapluginfilter.OPAPluginFilter", hooks=["tool_pre_invoke"], config={"opa_base_url": "http://127.0.0.1:8181/v1/data/"})
8161
mock_server = run_mock_opa()
8262
plugin = OPAPluginFilter(config)
8363

0 commit comments

Comments
 (0)