diff --git a/conf.json b/conf.json index 940c8bdd6..932f457b0 100644 --- a/conf.json +++ b/conf.json @@ -10,7 +10,8 @@ "pangea", "promptsecurity", "panw-prisma-airs", - "walledai" + "walledai", + "trend-ai" ], "credentials": { "portkey": { diff --git a/plugins/index.ts b/plugins/index.ts index 8d4dcb1b8..2828836c1 100644 --- a/plugins/index.ts +++ b/plugins/index.ts @@ -67,6 +67,8 @@ import { handler as defaultregexReplace } from './default/regexReplace'; import { handler as defaultallowedRequestTypes } from './default/allowedRequestTypes'; import { handler as javelinguardrails } from './javelin/guardrails'; import { handler as f5GuardrailsScan } from './f5-guardrails/scan'; +import { handler as trendAiGuard } from './trend-ai/guard'; + export const plugins = { default: { regexMatch: defaultregexMatch, @@ -177,4 +179,7 @@ export const plugins = { 'f5-guardrails': { scan: f5GuardrailsScan, }, + 'trend-ai': { + guard: trendAiGuard, + }, }; diff --git a/plugins/trend-ai/guard.ts b/plugins/trend-ai/guard.ts new file mode 100644 index 000000000..9e7df1791 --- /dev/null +++ b/plugins/trend-ai/guard.ts @@ -0,0 +1,122 @@ +import { + HookEventType, + PluginContext, + PluginHandler, + PluginParameters, +} from '../types'; +import { post, getText, HttpError } from '../utils'; +import { VERSION } from './version'; + +export const handler: PluginHandler = async ( + context: PluginContext, + parameters: PluginParameters, + eventType: HookEventType, + options?: { + env: Record; + getFromCacheByKey?: (key: string) => Promise; + putInCacheWithValue?: (key: string, value: any) => Promise; + } +) => { + let error = null; + let verdict = true; + let data = null; + + // Validate required parameters + if (!parameters.credentials?.v1Url) { + return { + error: { message: `'parameters.credentials.v1Url' must be set` }, + verdict: true, + data, + }; + } + + if (!parameters.credentials?.apiKey) { + return { + error: { message: `'parameters.credentials.apiKey' must be set` }, + verdict: true, + data, + }; + } + + // Extract text from context + const text = getText(context, eventType); + if (!text) { + return { + error: { message: 'request or response text is empty' }, + verdict: true, + data, + }; + } + const applicationName = parameters.applicationName; + + // Validate application name is provided and has correct format + if (!applicationName) { + return { + error: { message: 'Application name is required' }, + verdict: true, + data, + }; + } + + if (!/^[a-zA-Z0-9_-]+$/.test(applicationName)) { + return { + error: { + message: + 'Application name must contain only letters, numbers, hyphens, and underscores', + }, + verdict: true, + data, + }; + } + + // Prepare request headers + const headers: Record = { + 'Content-Type': 'application/json', + Accept: 'application/json', + Authorization: `Bearer ${parameters.credentials?.apiKey}`, + 'TMV1-Application-Name': applicationName, + }; + + // Set Prefer header + const preferValue = parameters.prefer || 'return=minimal'; + headers['Prefer'] = preferValue; + + const requestOptions = { headers }; + + // Prepare request payload for applyGuardrails endpoint + const request = { + prompt: text, + }; + + let response; + try { + response = await post( + parameters.credentials?.v1Url, + request, + requestOptions, + parameters.timeout + ); + } catch (e) { + if (e instanceof HttpError) { + error = { + message: `API request failed: ${e.message}. body: ${e.response.body}`, + }; + } else { + error = e as Error; + } + } + + if (response) { + data = response; + + if (response.action && response.action === 'Block') { + verdict = false; + } + } + + return { + error, + verdict, + data, + }; +}; diff --git a/plugins/trend-ai/manifest.json b/plugins/trend-ai/manifest.json new file mode 100644 index 000000000..622c8fc1e --- /dev/null +++ b/plugins/trend-ai/manifest.json @@ -0,0 +1,50 @@ +{ + "id": "trend-ai", + "description": "Trend AI Guard for scanning LLM inputs and outputs", + "credentials": { + "type": "object", + "properties": { + "v1ApiKey": { + "type": "string", + "label": "Trend AI Token", + "description": "Trend AI Guard token Get setup here (https://docs.trendmicro.com/en-us/documentation/article/trend-vision-one-ai-scanner-ai-guard)", + "encrypted": true + }, + "v1Url": { + "type": "string", + "label": "Trend AI URL", + "description": "Trend AI Guard URL (e.g., https://api.xdr.trendmicro.com/v3.0/aiSecurity/applyGuardrails)" + } + }, + "required": ["v1Url", "apiKey"] + }, + "functions": [ + { + "name": "Trend AI Guard for scanning LLM inputs and outputs", + "id": "aiGuard", + "supportedHooks": ["beforeRequestHook", "afterRequestHook"], + "type": "guardrail", + "description": [ + { + "type": "subHeading", + "text": "Analyze and scan text for security threats and policy violations using Trend AI Guard services." + } + ], + "parameters": { + "prefer": { + "type": "string", + "label": "Response Detail Level", + "description": "Controls the level of detail in the response. 'return=representation' returns detailed response with scanner results, 'return=minimal' returns short response with only action and reasons.", + "enum": ["return=representation", "return=minimal"], + "default": "return=minimal" + }, + "applicationName": { + "type": "string", + "label": "Application Name", + "description": "The name of the AI application whose prompts are being evaluated. Must contain only letters, numbers, hyphens, and underscores." + }, + "required": ["applicationName"] + } + } + ] +} diff --git a/plugins/trend-ai/trendai.test.ts b/plugins/trend-ai/trendai.test.ts new file mode 100644 index 000000000..b82c65b73 --- /dev/null +++ b/plugins/trend-ai/trendai.test.ts @@ -0,0 +1,458 @@ +import { handler as textGuardHandler } from './guard'; +import { HookEventType, PluginContext } from '../types'; + +const mockLogger = { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), +}; + +const options = { + env: { + ANOTHER_KEY: 'another-value', + }, +}; + +const mockCredentials = { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'jwt-token', +}; + +describe('TrendMicro textGuardHandler', () => { + it('should return an error if v1Url is not provided', async () => { + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + v1_api_key: 'jwt-token', + applicationName: 'test-app', + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error?.message).toBe( + "'parameters.credentials.v1Url' must be set" + ); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should return an error if apiKey is not provided', async () => { + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + }, + applicationName: 'test-app', + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error?.message).toBe( + "'parameters.credentials.apiKey' must be set" + ); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should return an error if text is empty', async () => { + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: '' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'jwt-token', + }, + applicationName: 'test-app', + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error?.message).toBe('request or response text is empty'); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should return an error if applicationName is not provided', async () => { + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'jwt-token', + }, + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error?.message).toBe('Application name is required'); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should return an error if applicationName has invalid format', async () => { + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'jwt-token', + }, + applicationName: 'invalid app name with spaces', + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error?.message).toBe( + 'Application name must contain only letters, numbers, hyphens, and underscores' + ); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should handle HTTP errors gracefully', async () => { + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://invalid-url-that-will-fail.com/v1/scan', + apiKey: 'TRENDMICRO_API_KEY', + }, + applicationName: 'test-app', + timeout: 1000, + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error).toBeDefined(); + expect(typeof result.error?.message).toBe('string'); + expect(result.error?.message).toContain('fetch failed'); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should work with afterRequestHook event type', async () => { + const context: PluginContext = { + response: { + json: { + choices: [ + { + message: { + role: 'assistant', + content: 'This is a response message', + }, + }, + ], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'afterRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'TRENDMICRO_API_KEY', + }, + applicationName: 'test-app', + }; + + // Since this will fail due to invalid URL, we just check that it processes the afterRequestHook + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error).toBeDefined(); + expect(result.verdict).toBe(true); + expect(result.data).toBeNull(); + }); + + it('should use correct headers and request format', async () => { + // Mock fetch to verify the request format + const originalFetch = global.fetch; + const mockFetch = jest.fn(); + global.fetch = mockFetch; + + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + status: 'success', + threat_detected: false, + message: 'Text is clean', + }), + }); + + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a clean test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'jwt-token', + }, + applicationName: 'test-app', + }; + + await textGuardHandler(context, parameters, eventType, options); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.trendmicro.com/v1/scan', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + 'Content-Type': 'application/json', + Authorization: 'Bearer jwt-token', + 'TMV1-Application-Name': 'test-app', + Prefer: 'return=minimal', + }), + body: expect.stringContaining( + '"prompt":"This is a clean test message"' + ), + }) + ); + + // Restore original fetch + global.fetch = originalFetch; + }); + + it('should return false verdict when threat is detected', async () => { + // Mock fetch to simulate threat detection + const originalFetch = global.fetch; + const mockFetch = jest.fn(); + global.fetch = mockFetch; + + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + action: 'Block', + reason: 'Prompt Attack Detected', + }), + }); + + const context: PluginContext = { + request: { + json: { + messages: [ + { role: 'user', content: 'This is a malicious test message' }, + ], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'TRENDMICRO_API_KEY', + }, + applicationName: 'test-app', + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(false); + expect(result.data).toBeDefined(); + + // Restore original fetch + global.fetch = originalFetch; + }); + + it('should return true verdict when no threat is detected', async () => { + // Mock fetch to simulate clean text + const originalFetch = global.fetch; + const mockFetch = jest.fn(); + global.fetch = mockFetch; + + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + status: 'success', + threat_detected: false, + message: 'Text is clean', + }), + }); + + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a clean test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'TRENDMICRO_API_KEY', + }, + applicationName: 'test-app', + }; + + const result = await textGuardHandler( + context, + parameters, + eventType, + options + ); + + expect(result.error).toBeNull(); + expect(result.verdict).toBe(true); + expect(result.data).toBeDefined(); + expect(result.data.threat_detected).toBe(false); + + // Restore original fetch + global.fetch = originalFetch; + }); + + it('should use correct Prefer header when prefer parameter is set', async () => { + // Mock fetch to verify the request format + const originalFetch = global.fetch; + const mockFetch = jest.fn(); + global.fetch = mockFetch; + + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + id: '1234567890abcdef', + action: 'Allow', + reasons: [], + }), + }); + + const context: PluginContext = { + request: { + json: { + messages: [{ role: 'user', content: 'This is a test message' }], + }, + }, + requestType: 'chatComplete', + logger: mockLogger, + }; + const eventType: HookEventType = 'beforeRequestHook'; + const parameters = { + credentials: { + v1Url: 'https://api.trendmicro.com/v1/scan', + apiKey: 'jwt-token', + }, + applicationName: 'test-app', + prefer: 'return=representation', + }; + + await textGuardHandler(context, parameters, eventType, options); + + expect(mockFetch).toHaveBeenCalledWith( + 'https://api.trendmicro.com/v1/scan', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + 'Content-Type': 'application/json', + Authorization: 'Bearer jwt-token', + 'TMV1-Application-Name': 'test-app', + Prefer: 'return=representation', + }), + body: expect.stringContaining('"prompt":"This is a test message"'), + }) + ); + + // Restore original fetch + global.fetch = originalFetch; + }); +}); diff --git a/plugins/trend-ai/version.ts b/plugins/trend-ai/version.ts new file mode 100644 index 000000000..a9d12425a --- /dev/null +++ b/plugins/trend-ai/version.ts @@ -0,0 +1 @@ +export const VERSION = 'v1.0.0';