Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/core/src/confirmation-bus/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export interface UpdatePolicy {
toolName: string;
persist?: boolean;
argsPattern?: string;
commandPrefix?: string;
commandPrefix?: string | string[];
mcpName?: string;
}

Expand Down
55 changes: 37 additions & 18 deletions packages/core/src/policy/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ interface TomlRule {
mcpName?: string;
decision?: string;
priority?: number;
commandPrefix?: string;
commandPrefix?: string | string[];
argsPattern?: string;
// Index signature to satisfy Record type if needed for toml.stringify
[key: string]: unknown;
Expand All @@ -259,26 +259,45 @@ export function createPolicyUpdater(
MessageBusType.UPDATE_POLICY,
async (message: UpdatePolicy) => {
const toolName = message.toolName;
let argsPattern = message.argsPattern
? new RegExp(message.argsPattern)
: undefined;

if (message.commandPrefix) {
// Convert commandPrefix to argsPattern for in-memory rule
// This mimics what toml-loader does
const escapedPrefix = escapeRegex(message.commandPrefix);
argsPattern = new RegExp(`"command":"${escapedPrefix}`);
}
// Convert commandPrefix(es) to argsPatterns for in-memory rules
const prefixes = Array.isArray(message.commandPrefix)
? message.commandPrefix
: [message.commandPrefix];

for (const prefix of prefixes) {
const escapedPrefix = escapeRegex(prefix);
// Use robust regex to match whole words (e.g. "git" but not "github")
const argsPattern = new RegExp(
`"command":"${escapedPrefix}(?:[\\s"]|$)`,
);

policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
// User tier (2) + high priority (950/1000) = 2.95
// This ensures user "always allow" selections are high priority
// but still lose to admin policies (3.xxx) and settings excludes (200)
priority: 2.95,
argsPattern,
});
policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
// User tier (2) + high priority (950/1000) = 2.95
// This ensures user "always allow" selections are high priority
// but still lose to admin policies (3.xxx) and settings excludes (200)
priority: 2.95,
argsPattern,
});
}
} else {
const argsPattern = message.argsPattern
? new RegExp(message.argsPattern)
: undefined;

policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
// User tier (2) + high priority (950/1000) = 2.95
// This ensures user "always allow" selections are high priority
// but still lose to admin policies (3.xxx) and settings excludes (200)
priority: 2.95,
argsPattern,
});
}

if (message.persist) {
try {
Expand Down
4 changes: 3 additions & 1 deletion packages/core/src/policy/persistence.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ describe('createPolicyUpdater', () => {
const addedRule = rules.find((r) => r.toolName === toolName);
expect(addedRule).toBeDefined();
expect(addedRule?.priority).toBe(2.95);
expect(addedRule?.argsPattern).toEqual(new RegExp(`"command":"git status`));
expect(addedRule?.argsPattern).toEqual(
new RegExp(`"command":"git status(?:[\\s"]|$)`),
);

// Verify file written
expect(fs.writeFile).toHaveBeenCalledWith(
Expand Down
190 changes: 190 additions & 0 deletions packages/core/src/policy/policy-updater.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import * as fs from 'node:fs/promises';
import { createPolicyUpdater } from './config.js';
import { PolicyEngine } from './policy-engine.js';
import { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import { Storage } from '../config/storage.js';
import toml from '@iarna/toml';
import { ShellToolInvocation } from '../tools/shell.js';
import { type Config } from '../config/config.js';
import {
ToolConfirmationOutcome,
type PolicyUpdateOptions,
} from '../tools/tools.js';
import * as shellUtils from '../utils/shell-utils.js';

vi.mock('node:fs/promises');
vi.mock('../config/storage.js');
vi.mock('../utils/shell-utils.js', () => ({
getCommandRoots: vi.fn(),
stripShellWrapper: vi.fn(),
}));
interface ParsedPolicy {
rule?: Array<{
commandPrefix?: string | string[];
}>;
}

interface TestableShellToolInvocation {
getPolicyUpdateOptions(
outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined;
}

describe('createPolicyUpdater', () => {
let policyEngine: PolicyEngine;
let messageBus: MessageBus;

beforeEach(() => {
vi.resetAllMocks();
policyEngine = new PolicyEngine({});
vi.spyOn(policyEngine, 'addRule');

messageBus = new MessageBus(policyEngine);
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(
'/mock/user/policies',
);
});

afterEach(() => {
vi.restoreAllMocks();
});

it('should add multiple rules when commandPrefix is an array', async () => {
createPolicyUpdater(policyEngine, messageBus);

await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'run_shell_command',
commandPrefix: ['echo', 'ls'],
persist: false,
});

expect(policyEngine.addRule).toHaveBeenCalledTimes(2);
expect(policyEngine.addRule).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
toolName: 'run_shell_command',
argsPattern: new RegExp('"command":"echo(?:[\\s"]|$)'),
}),
);
expect(policyEngine.addRule).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
toolName: 'run_shell_command',
argsPattern: new RegExp('"command":"ls(?:[\\s"]|$)'),
}),
);
});

it('should add a single rule when commandPrefix is a string', async () => {
createPolicyUpdater(policyEngine, messageBus);

await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'run_shell_command',
commandPrefix: 'git',
persist: false,
});

expect(policyEngine.addRule).toHaveBeenCalledTimes(1);
expect(policyEngine.addRule).toHaveBeenCalledWith(
expect.objectContaining({
toolName: 'run_shell_command',
argsPattern: new RegExp('"command":"git(?:[\\s"]|$)'),
}),
);
});

it('should persist multiple rules correctly to TOML', async () => {
createPolicyUpdater(policyEngine, messageBus);
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
vi.mocked(fs.rename).mockResolvedValue(undefined);

await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'run_shell_command',
commandPrefix: ['echo', 'ls'],
persist: true,
});

// Wait for the async listener to complete
await new Promise((resolve) => setTimeout(resolve, 0));

expect(fs.writeFile).toHaveBeenCalled();
const [_path, content] = vi.mocked(fs.writeFile).mock.calls[0] as [
string,
string,
];
const parsed = toml.parse(content) as unknown as ParsedPolicy;

expect(parsed.rule).toHaveLength(1);
expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']);
});
});

describe('ShellToolInvocation Policy Update', () => {
let mockConfig: Config;
let mockMessageBus: MessageBus;

beforeEach(() => {
vi.resetAllMocks();
mockConfig = {} as Config;
mockMessageBus = {} as MessageBus;

vi.mocked(shellUtils.stripShellWrapper).mockImplementation(
(c: string) => c,
);
});

it('should extract multiple root commands for chained commands', () => {
vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['git', 'npm']);

const invocation = new ShellToolInvocation(
mockConfig,
{ command: 'git status && npm test' },
new Set(),
mockMessageBus,
'run_shell_command',
'Shell',
);

// Accessing protected method for testing
const options = (
invocation as unknown as TestableShellToolInvocation
).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways);
expect(options!.commandPrefix).toEqual(['git', 'npm']);
expect(shellUtils.getCommandRoots).toHaveBeenCalledWith(
'git status && npm test',
);
});

it('should extract a single root command', () => {
vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['ls']);

const invocation = new ShellToolInvocation(
mockConfig,
{ command: 'ls -la /tmp' },
new Set(),
mockMessageBus,
'run_shell_command',
'Shell',
);

// Accessing protected method for testing
const options = (
invocation as unknown as TestableShellToolInvocation
).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways);
expect(options!.commandPrefix).toEqual(['ls']);
expect(shellUtils.getCommandRoots).toHaveBeenCalledWith('ls -la /tmp');
});
});
10 changes: 9 additions & 1 deletion packages/core/src/tools/shell.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,15 @@ export class ShellToolInvocation extends BaseToolInvocation<
protected override getPolicyUpdateOptions(
outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) {
if (
outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave ||
outcome === ToolConfirmationOutcome.ProceedAlways
) {
const command = stripShellWrapper(this.params.command);
const rootCommands = [...new Set(getCommandRoots(command))];
if (rootCommands.length > 0) {
return { commandPrefix: rootCommands };
}
return { commandPrefix: this.params.command };
}
return undefined;
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/tools/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export interface ToolInvocation<
* Options for policy updates that can be customized by tool invocations.
*/
export interface PolicyUpdateOptions {
commandPrefix?: string;
commandPrefix?: string | string[];
mcpName?: string;
}

Expand Down