Skip to content

Commit 6c28ee5

Browse files
feat: session policy pinning + tool alias registry (#96)
- Deep-copy policy in create_context() so sessions get pinned snapshots that aren't mutated by later policy changes (Closes #92) - Add ToolAliasRegistry with default canonical mappings for 7 tool families (web_search, file_read/write, shell_execute, code_execute, database_query, http_request) — prevents policy bypass via tool renaming (Closes #94) - Export ToolAliasRegistry from integrations __init__ - 20 tests covering pinning isolation and alias bypass prevention Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent cc90928 commit 6c28ee5

File tree

4 files changed

+385
-2
lines changed

4 files changed

+385
-2
lines changed

packages/agent-os/src/agent_os/integrations/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
from .rate_limiter import RateLimiter, RateLimitStatus
115115
from .templates import PolicyTemplates
116116
from .token_budget import TokenBudgetStatus, TokenBudgetTracker
117+
from .tool_aliases import ToolAliasRegistry
117118
from .webhooks import DeliveryRecord, WebhookConfig, WebhookEvent, WebhookNotifier
118119

119120
__all__ = [
@@ -194,6 +195,8 @@
194195
"check_compatibility",
195196
"CompatReport",
196197
"warn_on_import",
198+
# Tool Aliases
199+
"ToolAliasRegistry",
197200
# Rate Limiting
198201
"RateLimiter",
199202
"RateLimitStatus",

packages/agent-os/src/agent_os/integrations/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import annotations
1010

1111
import asyncio
12+
import copy
1213
import difflib
1314
import fnmatch
1415
import hashlib
@@ -822,12 +823,17 @@ def unwrap(self, governed_agent: Any) -> Any:
822823
pass
823824

824825
def create_context(self, agent_id: str) -> ExecutionContext:
    """Build, register, and return a new execution context for an agent.

    The context is given a **deep copy** of ``self.policy`` rather than a
    reference to it, so the session stays pinned to the policy as it was
    at creation time; later changes to ``self.policy`` do not reach the
    returned context.

    Args:
        agent_id: Identifier of the agent; also used as the key under
            which the context is stored in ``self.contexts`` (an existing
            entry for the same id is replaced).

    Returns:
        The newly created ``ExecutionContext``.
    """
    from uuid import uuid4

    # The first eight characters of a UUID's string form are its first
    # eight hex digits, so ``.hex[:8]`` yields the same short session id.
    short_session_id = uuid4().hex[:8]

    # Snapshot the policy so this session is isolated from later edits.
    pinned_policy = copy.deepcopy(self.policy)

    context = ExecutionContext(
        agent_id=agent_id,
        session_id=short_session_id,
        policy=pinned_policy,
    )
    self.contexts[agent_id] = context
    return context
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""
4+
Tool Alias Registry for Capability Canonicalization.
5+
6+
Maps tool name variants to canonical capability identifiers so that
7+
policy allowlists/blocklists cannot be bypassed by renaming tools.
8+
9+
Usage:
10+
from agent_os.integrations.tool_aliases import ToolAliasRegistry
11+
12+
registry = ToolAliasRegistry()
13+
registry.register_alias("bing_search", "web_search")
14+
registry.register_alias("search_web", "web_search")
15+
registry.register_alias("google_search", "web_search")
16+
17+
assert registry.canonicalize("bing_search") == "web_search"
18+
assert registry.canonicalize("unknown_tool") == "unknown_tool"
19+
"""
20+
21+
from __future__ import annotations
22+
23+
import logging
24+
import re
25+
from typing import Optional
26+
27+
logger = logging.getLogger(__name__)
28+
29+
# Default canonical mappings for common tool families.
30+
# Keys are alias patterns, values are canonical names.
31+
DEFAULT_ALIASES: dict[str, str] = {
    # Flatten "canonical -> known variants" families into the
    # alias -> canonical lookup table. Family and variant order is kept
    # so the resulting dict has a stable, grouped insertion order.
    variant: canonical_name
    for canonical_name, variants in (
        # Search tools
        (
            "web_search",
            (
                "bing_search",
                "google_search",
                "search_web",
                "internet_search",
                "duckduckgo_search",
            ),
        ),
        # File operations
        ("file_read", ("read_file", "file_read", "get_file", "load_file")),
        ("file_write", ("write_file", "file_write", "save_file", "create_file")),
        # Shell execution
        (
            "shell_execute",
            (
                "shell_exec",
                "shell_execute",
                "run_command",
                "exec_command",
                "bash",
                "terminal",
            ),
        ),
        # Code execution
        (
            "code_execute",
            ("python_exec", "run_python", "execute_code", "eval_code"),
        ),
        # Database operations
        ("database_query", ("sql_query", "run_sql", "execute_sql", "db_query")),
        # HTTP operations
        ("http_request", ("http_request", "api_call", "fetch_url", "curl")),
    )
    for variant in variants
}
70+
71+
72+
class ToolAliasRegistry:
    """Maps tool name variants to canonical capability identifiers.

    Provides both exact-match aliases and regex pattern-based matching
    for tool name canonicalization. Prevents policy bypass via tool
    renaming.

    Resolution is transitive: if ``a`` is an alias of ``b`` and ``b`` is
    an alias of ``c``, then ``a`` canonicalizes to ``c``. Without this, an
    alias registered against another alias would escape a policy written
    against the real canonical name.

    Args:
        use_defaults: If True, loads the default alias mappings.
    """

    def __init__(self, use_defaults: bool = True) -> None:
        # Exact alias -> target name; both sides are stored lowercased.
        self._aliases: dict[str, str] = {}
        # (compiled regex, lowercased canonical name), in registration order.
        self._patterns: list[tuple[re.Pattern, str]] = []
        if use_defaults:
            # Copy the entries so the shared module-level table is never
            # mutated through this instance.
            self._aliases.update(DEFAULT_ALIASES)

    def register_alias(self, alias: str, canonical: str) -> None:
        """Register a tool name alias.

        Registering an alias that already exists replaces its mapping.

        Args:
            alias: The alternative tool name (case-insensitive).
            canonical: The canonical capability name it maps to. It may
                itself be an alias; chains are followed at lookup time.
        """
        self._aliases[alias.lower()] = canonical.lower()

    def register_pattern(self, pattern: str, canonical: str) -> None:
        """Register a regex pattern that maps matching tool names.

        Patterns are compiled case-insensitively and are tried in
        registration order, using ``search`` (so an unanchored pattern
        matches anywhere in the tool name).

        Args:
            pattern: Regex pattern to match tool names against.
            canonical: The canonical capability name for matches.

        Raises:
            re.error: If ``pattern`` is not a valid regular expression.
        """
        self._patterns.append((re.compile(pattern, re.IGNORECASE), canonical.lower()))

    def _resolve_once(self, name: str) -> Optional[str]:
        """Return the direct mapping for a lowercased name, or None.

        Exact aliases take precedence over regex patterns; among patterns
        the first registered match wins.
        """
        target = self._aliases.get(name)
        if target is not None:
            return target
        for pattern, canonical in self._patterns:
            if pattern.search(name):
                return canonical
        return None

    def canonicalize(self, tool_name: str) -> str:
        """Resolve a tool name to its canonical form.

        At each step exact aliases are checked first, then regex
        patterns. The result is resolved again until it no longer maps to
        anything new, so alias chains (and pattern targets that are
        themselves aliases) end at the same canonical name. Returns the
        original name (lowercased) if no mapping is found.

        Args:
            tool_name: The tool name to canonicalize.

        Returns:
            The canonical capability name.
        """
        current = tool_name.lower()
        # Names already visited; guards against self-maps such as
        # "file_read" -> "file_read" and against accidental alias cycles.
        seen = {current}
        while True:
            target = self._resolve_once(current)
            if target is None or target in seen:
                return current
            seen.add(target)
            current = target

    def is_allowed(self, tool_name: str, allowed_tools: list[str]) -> bool:
        """Check if a tool is in the allowed list after canonicalization.

        Both the tool name and all entries in the allowed list are
        canonicalized before comparison.

        Args:
            tool_name: Tool name to check.
            allowed_tools: List of allowed tool names/capabilities.

        Returns:
            True if the canonicalized tool is in the canonicalized
            allowlist, or if the allowlist is empty.
        """
        if not allowed_tools:
            return True  # Empty allowlist = all allowed
        canonical = self.canonicalize(tool_name)
        allowed_canonical = {self.canonicalize(t) for t in allowed_tools}
        return canonical in allowed_canonical

    def is_blocked(self, tool_name: str, blocked_tools: list[str]) -> bool:
        """Check if a tool is in a block list after canonicalization.

        Both the tool name and all entries in the block list are
        canonicalized before comparison.

        Args:
            tool_name: Tool name to check.
            blocked_tools: List of blocked tool names/capabilities.

        Returns:
            True if the canonicalized tool is in the canonicalized
            blocklist; False for an empty blocklist.
        """
        if not blocked_tools:
            return False
        canonical = self.canonicalize(tool_name)
        blocked_canonical = {self.canonicalize(t) for t in blocked_tools}
        return canonical in blocked_canonical

    def get_aliases(self, canonical: str) -> list[str]:
        """Get all known aliases for a canonical tool name.

        An alias counts when it *resolves* to ``canonical``, so aliases
        that reach it through a chain of other aliases are included.
        Regex patterns are not enumerated.

        Args:
            canonical: The canonical capability name.

        Returns:
            List of alias names that map to this canonical name, in
            registration order.
        """
        canonical_lower = canonical.lower()
        return [
            alias
            for alias in self._aliases
            if self.canonicalize(alias) == canonical_lower
        ]

    def list_canonical_tools(self) -> list[str]:
        """List all unique canonical tool names reachable from exact aliases.

        Intermediate names in an alias chain are not reported; only the
        name each alias finally resolves to.
        """
        return sorted({self.canonicalize(alias) for alias in self._aliases})

    def __len__(self) -> int:
        """Return the number of registered exact aliases (patterns excluded)."""
        return len(self._aliases)

    def __contains__(self, tool_name: object) -> bool:
        """Return True if ``tool_name`` is a registered exact alias.

        The check is case-insensitive. Non-string operands are simply
        reported as absent instead of raising.
        """
        return isinstance(tool_name, str) and tool_name.lower() in self._aliases

0 commit comments

Comments
 (0)