diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 6f2f9a2e12f..a26b786701c 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -41,7 +41,7 @@ export interface UpdatePolicy { toolName: string; persist?: boolean; argsPattern?: string; - commandPrefix?: string; + commandPrefix?: string | string[]; mcpName?: string; } diff --git a/packages/core/src/policy/config.ts b/packages/core/src/policy/config.ts index 9057506ae2a..c4888c48821 100644 --- a/packages/core/src/policy/config.ts +++ b/packages/core/src/policy/config.ts @@ -245,7 +245,7 @@ interface TomlRule { mcpName?: string; decision?: string; priority?: number; - commandPrefix?: string; + commandPrefix?: string | string[]; argsPattern?: string; // Index signature to satisfy Record type if needed for toml.stringify [key: string]: unknown; @@ -259,26 +259,45 @@ export function createPolicyUpdater( MessageBusType.UPDATE_POLICY, async (message: UpdatePolicy) => { const toolName = message.toolName; - let argsPattern = message.argsPattern - ? new RegExp(message.argsPattern) - : undefined; if (message.commandPrefix) { - // Convert commandPrefix to argsPattern for in-memory rule - // This mimics what toml-loader does - const escapedPrefix = escapeRegex(message.commandPrefix); - argsPattern = new RegExp(`"command":"${escapedPrefix}`); - } + // Convert commandPrefix(es) to argsPatterns for in-memory rules + const prefixes = Array.isArray(message.commandPrefix) + ? message.commandPrefix + : [message.commandPrefix]; + + for (const prefix of prefixes) { + const escapedPrefix = escapeRegex(prefix); + // Use robust regex to match whole words (e.g. "git" but not "github") + const argsPattern = new RegExp( + `"command":"${escapedPrefix}(?:[\\s"]|$)`, + ); - policyEngine.addRule({ - toolName, - decision: PolicyDecision.ALLOW, - // User tier (2) + high priority (950/1000) = 2.95 - // This ensures user "always allow" selections are high priority - // but still lose to admin policies (3.xxx) and settings excludes (200) - priority: 2.95, - argsPattern, - }); + policyEngine.addRule({ + toolName, + decision: PolicyDecision.ALLOW, + // User tier (2) + high priority (950/1000) = 2.95 + // This ensures user "always allow" selections are high priority + // but still lose to admin policies (3.xxx) and settings excludes (200) + priority: 2.95, + argsPattern, + }); + } + } else { + const argsPattern = message.argsPattern + ? new RegExp(message.argsPattern) + : undefined; + + policyEngine.addRule({ + toolName, + decision: PolicyDecision.ALLOW, + // User tier (2) + high priority (950/1000) = 2.95 + // This ensures user "always allow" selections are high priority + // but still lose to admin policies (3.xxx) and settings excludes (200) + priority: 2.95, + argsPattern, + }); + } if (message.persist) { try { diff --git a/packages/core/src/policy/persistence.test.ts b/packages/core/src/policy/persistence.test.ts index 4954d7280ee..7743de752fc 100644 --- a/packages/core/src/policy/persistence.test.ts +++ b/packages/core/src/policy/persistence.test.ts @@ -126,7 +126,9 @@ describe('createPolicyUpdater', () => { const addedRule = rules.find((r) => r.toolName === toolName); expect(addedRule).toBeDefined(); expect(addedRule?.priority).toBe(2.95); - expect(addedRule?.argsPattern).toEqual(new RegExp(`"command":"git status`)); + expect(addedRule?.argsPattern).toEqual( + new RegExp(`"command":"git status(?:[\\s"]|$)`), + ); // Verify file written expect(fs.writeFile).toHaveBeenCalledWith( diff --git a/packages/core/src/policy/policy-updater.test.ts b/packages/core/src/policy/policy-updater.test.ts new file mode 100644 index 00000000000..acde845e3a0 --- /dev/null +++ b/packages/core/src/policy/policy-updater.test.ts @@ -0,0 +1,190 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import * as fs from 'node:fs/promises'; +import { createPolicyUpdater } from './config.js'; +import { PolicyEngine } from './policy-engine.js'; +import { MessageBus } from '../confirmation-bus/message-bus.js'; +import { MessageBusType } from '../confirmation-bus/types.js'; +import { Storage } from '../config/storage.js'; +import toml from '@iarna/toml'; +import { ShellToolInvocation } from '../tools/shell.js'; +import { type Config } from '../config/config.js'; +import { + ToolConfirmationOutcome, + type PolicyUpdateOptions, +} from '../tools/tools.js'; +import * as shellUtils from '../utils/shell-utils.js'; + +vi.mock('node:fs/promises'); +vi.mock('../config/storage.js'); +vi.mock('../utils/shell-utils.js', () => ({ + getCommandRoots: vi.fn(), + stripShellWrapper: vi.fn(), +})); +interface ParsedPolicy { + rule?: Array<{ + commandPrefix?: string | string[]; + }>; +} + +interface TestableShellToolInvocation { + getPolicyUpdateOptions( + outcome: ToolConfirmationOutcome, + ): PolicyUpdateOptions | undefined; +} + +describe('createPolicyUpdater', () => { + let policyEngine: PolicyEngine; + let messageBus: MessageBus; + + beforeEach(() => { + vi.resetAllMocks(); + policyEngine = new PolicyEngine({}); + vi.spyOn(policyEngine, 'addRule'); + + messageBus = new MessageBus(policyEngine); + vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue( + '/mock/user/policies', + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should add multiple rules when commandPrefix is an array', async () => { + createPolicyUpdater(policyEngine, messageBus); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'run_shell_command', + commandPrefix: ['echo', 'ls'], + persist: false, + }); + + expect(policyEngine.addRule).toHaveBeenCalledTimes(2); + expect(policyEngine.addRule).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + toolName: 'run_shell_command', + argsPattern: new RegExp('"command":"echo(?:[\\s"]|$)'), + }), + ); + expect(policyEngine.addRule).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + toolName: 'run_shell_command', + argsPattern: new RegExp('"command":"ls(?:[\\s"]|$)'), + }), + ); + }); + + it('should add a single rule when commandPrefix is a string', async () => { + createPolicyUpdater(policyEngine, messageBus); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'run_shell_command', + commandPrefix: 'git', + persist: false, + }); + + expect(policyEngine.addRule).toHaveBeenCalledTimes(1); + expect(policyEngine.addRule).toHaveBeenCalledWith( + expect.objectContaining({ + toolName: 'run_shell_command', + argsPattern: new RegExp('"command":"git(?:[\\s"]|$)'), + }), + ); + }); + + it('should persist multiple rules correctly to TOML', async () => { + createPolicyUpdater(policyEngine, messageBus); + vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + vi.mocked(fs.writeFile).mockResolvedValue(undefined); + vi.mocked(fs.rename).mockResolvedValue(undefined); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'run_shell_command', + commandPrefix: ['echo', 'ls'], + persist: true, + }); + + // Wait for the async listener to complete + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(fs.writeFile).toHaveBeenCalled(); + const [_path, content] = vi.mocked(fs.writeFile).mock.calls[0] as [ + string, + string, + ]; + const parsed = toml.parse(content) as unknown as ParsedPolicy; + + expect(parsed.rule).toHaveLength(1); + expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']); + }); +}); + +describe('ShellToolInvocation Policy Update', () => { + let mockConfig: Config; + let mockMessageBus: MessageBus; + + beforeEach(() => { + vi.resetAllMocks(); + mockConfig = {} as Config; + mockMessageBus = {} as MessageBus; + + vi.mocked(shellUtils.stripShellWrapper).mockImplementation( + (c: string) => c, + ); + }); + + it('should extract multiple root commands for chained commands', () => { + vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['git', 'npm']); + + const invocation = new ShellToolInvocation( + mockConfig, + { command: 'git status && npm test' }, + new Set(), + mockMessageBus, + 'run_shell_command', + 'Shell', + ); + + // Accessing protected method for testing + const options = ( + invocation as unknown as TestableShellToolInvocation + ).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways); + expect(options!.commandPrefix).toEqual(['git', 'npm']); + expect(shellUtils.getCommandRoots).toHaveBeenCalledWith( + 'git status && npm test', + ); + }); + + it('should extract a single root command', () => { + vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['ls']); + + const invocation = new ShellToolInvocation( + mockConfig, + { command: 'ls -la /tmp' }, + new Set(), + mockMessageBus, + 'run_shell_command', + 'Shell', + ); + + // Accessing protected method for testing + const options = ( + invocation as unknown as TestableShellToolInvocation + ).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways); + expect(options!.commandPrefix).toEqual(['ls']); + expect(shellUtils.getCommandRoots).toHaveBeenCalledWith('ls -la /tmp'); + }); +}); diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index b4f79a8b0c1..9103480c5d0 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -89,7 +89,15 @@ export class ShellToolInvocation extends BaseToolInvocation< protected override getPolicyUpdateOptions( outcome: ToolConfirmationOutcome, ): PolicyUpdateOptions | undefined { - if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) { + if ( + outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave || + outcome === ToolConfirmationOutcome.ProceedAlways + ) { + const command = stripShellWrapper(this.params.command); + const rootCommands = [...new Set(getCommandRoots(command))]; + if (rootCommands.length > 0) { + return { commandPrefix: rootCommands }; + } return { commandPrefix: this.params.command }; } return undefined; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 2c5c8018b39..d4b7fc3094e 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -69,7 +69,7 @@ export interface ToolInvocation< * Options for policy updates that can be customized by tool invocations. */ export interface PolicyUpdateOptions { - commandPrefix?: string; + commandPrefix?: string | string[]; mcpName?: string; }