-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathguardrail_plugin.py
More file actions
149 lines (121 loc) · 4.47 KB
/
guardrail_plugin.py
File metadata and controls
149 lines (121 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""ADK plugin: bind MCP tool org arguments to the authenticated caller.
If tool arguments contain no org-style parameters, the plugin takes no action.
Otherwise every org value must equal ``get_request_org_id()`` from the JWT
request context; mismatches are blocked. If org parameters are present but
request ``org_id`` is missing, the call is blocked as well (cannot verify tenant).
Downstream MCP and APIs still enforce auth; this mitigates cross-tenant
argument injection at the agent layer.
"""
from __future__ import annotations
import logging
from typing import Any
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from lightspeed_agent.auth.middleware import get_request_org_id
from lightspeed_agent.config import get_settings
logger = logging.getLogger(__name__)

# Argument key names, already in normalized form (lowercase, underscores
# instead of hyphens), that are treated as carrying a tenant org ID.
_ORG_ARG_NAMES = frozenset(
    (
        "org_id",
        "orgid",
        "organization_id",
        "organizationid",
        "rh_org_id",
    )
)

# Machine-readable code placed in every block response.
_BLOCK_CODE = "guardrail_org_mismatch"

# Human-readable reasons returned to the agent when a tool call is blocked.
_MSG_MISMATCH = (
    "Tool blocked: organization parameter does not match "
    "the authenticated caller."
)
_MSG_NO_ORG_CONTEXT = (
    "Tool blocked: organization parameter present "
    "but caller has no organization context."
)
def _normalize_key(key: str) -> str:
    """Return *key* lowercased, with every hyphen turned into an underscore."""
    lowered = key.lower()
    # Splitting on "-" and re-joining with "_" swaps each hyphen one-for-one.
    return "_".join(lowered.split("-"))
def _scalar_org_value(val: Any) -> str | None:
"""Extract a comparable org id string, or None if not a scalar."""
if val is None or isinstance(val, bool):
return None
if isinstance(val, str):
s = val.strip()
return s or None
if isinstance(val, int):
return str(val)
if isinstance(val, float):
return str(int(val)) if val.is_integer() else str(val)
return None
def _append_org_values_for_key(v: Any, found: list[str]) -> None:
    """Append org id strings found in *v*, the value of an org-related key.

    ``v`` may be a single scalar or a list/tuple of scalars (e.g. ``org_id``:
    ``["111", "222"]``), nested to any depth. Nested sequences are flattened
    here because the generic traversal in ``_collect_org_values`` only looks
    at dict keys: once the org-related parent key is out of scope, a value
    such as ``[["222"]]`` would otherwise never be compared.

    Dicts and other non-scalar values are skipped here; dicts are left to the
    caller's recursive traversal.

    Args:
        v: The value stored under an org-related argument key.
        found: Accumulator list, extended in place with each org id string.
    """
    if isinstance(v, list | tuple):
        for item in v:
            # Recurse instead of testing only scalars so nested sequences
            # cannot be used to hide an org id from the guardrail.
            _append_org_values_for_key(item, found)
        return

    s = _scalar_org_value(v)
    if s is not None:
        found.append(s)
def _collect_org_values(obj: Any) -> list[str]:
    """Recursively collect org-related scalar values from tool arguments.

    Dicts are scanned key by key: when a key normalizes to one of
    ``_ORG_ARG_NAMES`` its value is handed to ``_append_org_values_for_key``,
    and every value is then traversed again so org keys nested deeper are
    found as well. Lists and tuples are traversed element by element. Any
    other type contributes nothing.

    Args:
        obj: The tool arguments, or any nested value within them.

    Returns:
        All org id strings found, in traversal order (duplicates kept).
    """
    found: list[str] = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            # Keys are stringified first so non-string keys cannot raise.
            if _normalize_key(str(k)) in _ORG_ARG_NAMES:
                _append_org_values_for_key(v, found)
            found.extend(_collect_org_values(v))
    elif isinstance(obj, list | tuple):
        # Tuples are traversed like lists, matching the list/tuple handling
        # in _append_org_values_for_key, so a dict inside a tuple is checked.
        for item in obj:
            found.extend(_collect_org_values(item))
    return found
def _block_response(message: str) -> dict[str, Any]:
    """Build the dict returned in place of a tool result when a call is blocked.

    Args:
        message: Human-readable reason, stored under ``"error"``.

    Returns:
        A dict with ``"error"``, ``"code"`` (always ``_BLOCK_CODE``) and
        ``"blocked"`` (always ``True``).
    """
    return dict(error=message, code=_BLOCK_CODE, blocked=True)
class GuardrailPlugin(BasePlugin):
    """Plugin that checks org/tenant arguments of tool calls against the caller."""

    def __init__(self) -> None:
        super().__init__(name="guardrail")

    async def before_tool_callback(
        self,
        *,
        tool: BaseTool,
        tool_args: dict[str, Any],
        tool_context: ToolContext,
    ) -> dict[str, Any] | None:
        """Validate org-style values in *tool_args* before the tool is invoked.

        Args:
            tool: The tool about to be called; only its name is used (logging).
            tool_args: Arguments for the tool; searched recursively for
                org-related keys.
            tool_context: Supplies ``invocation_id`` for log messages.

        Returns:
            ``None`` when the guardrail is disabled in settings, when no org
            values are present, or when every org value equals the request
            org id. Otherwise a block-response dict from ``_block_response``.
        """
        settings = get_settings()
        if not settings.guardrail_org_args_enabled:
            return None

        tool_name = getattr(tool, "name", type(tool).__name__)
        supplied_orgs = _collect_org_values(tool_args)
        if not supplied_orgs:
            # Nothing tenant-related in the arguments: no action needed.
            return None

        caller_org = get_request_org_id()
        if caller_org is None:
            # Org arguments cannot be verified without a caller org id.
            logger.warning(
                "Guardrail blocked tool=%s invocation_id=%s: org args %s but "
                "no request org_id",
                tool_name,
                tool_context.invocation_id,
                supplied_orgs,
            )
            return _block_response(_MSG_NO_ORG_CONTEXT)

        caller_org = caller_org.strip()
        # Collected values are always non-empty strings, so None can safely
        # serve as the "no mismatch" sentinel.
        offending = next(
            (candidate for candidate in supplied_orgs if candidate != caller_org),
            None,
        )
        if offending is None:
            return None

        logger.warning(
            "Guardrail blocked tool=%s invocation_id=%s: org arg %r != %r",
            tool_name,
            tool_context.invocation_id,
            offending,
            caller_org,
        )
        return _block_response(_MSG_MISMATCH)