diff --git a/src/McpContext.ts b/src/McpContext.ts index 11bb3d971..ff31d1c0e 100644 --- a/src/McpContext.ts +++ b/src/McpContext.ts @@ -58,6 +58,8 @@ interface McpContextOptions { experimentalDevToolsDebugging: boolean; // Whether all page-like targets are exposed as pages. experimentalIncludeAllPages?: boolean; + // Custom headers to add to all network requests made by the browser. + headers?: Record; } const DEFAULT_TIMEOUT = 5_000; @@ -104,6 +106,9 @@ export class McpContext implements Context { #textSnapshot: TextSnapshot | null = null; #networkCollector: NetworkCollector; #consoleCollector: ConsoleCollector; + + // Custom headers to add to all network requests made by the browser. + #headers?: Record; #isRunningTrace = false; #networkConditionsMap = new WeakMap(); @@ -127,8 +132,11 @@ export class McpContext implements Context { this.logger = logger; this.#locatorClass = locatorClass; this.#options = options; + this.#headers = options.headers; - this.#networkCollector = new NetworkCollector(this.browser); + this.#networkCollector = new NetworkCollector(this.browser, undefined, { + headers: this.#headers, + }); this.#consoleCollector = new ConsoleCollector(this.browser, collect => { return { @@ -675,6 +683,8 @@ export class McpContext implements Context { collect(req); }, } as ListenerMap; + }, { + headers: this.#headers, }); await this.#networkCollector.init(await this.browser.pages()); } diff --git a/src/PageCollector.ts b/src/PageCollector.ts index 323d7fdbe..d89a74ad3 100644 --- a/src/PageCollector.ts +++ b/src/PageCollector.ts @@ -349,6 +349,8 @@ class PageIssueSubscriber { } export class NetworkCollector extends PageCollector { + #headers?: Record; + constructor( browser: Browser, listeners: ( @@ -360,8 +362,36 @@ export class NetworkCollector extends PageCollector { }, } as ListenerMap; }, + options?: Record & { + headers?: Record + } ) { super(browser, listeners); + if (options?.headers) { + this.#headers = options?.headers; + } + } + + override async init(pages: Page[]): Promise { + for (const page of pages) { + await this.#applyHeadersToPage(page); + } + await super.init(pages); + } + + override addPage(page: Page): void { + super.addPage(page); + void this.#applyHeadersToPage(page); + } + + async #applyHeadersToPage(page: Page): Promise { + if (this.#headers) { + try { + await page.setExtraHTTPHeaders(this.#headers); + } catch (error) { + logger('Error applying headers to page:', error); + } + } } override splitAfterNavigation(page: Page) { const navigations = this.storage.get(page) ?? []; diff --git a/src/cli.ts b/src/cli.ts index db2680587..26d65e6a9 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -87,6 +87,27 @@ export const cliOptions = { } }, }, + headers: { + type: 'string', + description: + 'Custom headers to add to all network requests made by the browser in JSON format (e.g., \'{"x-env":"visit_from_mcp","x-mock-user":"mcp"}\').', + coerce: (val: string | undefined) => { + if (!val) { + return; + } + try { + const parsed = JSON.parse(val); + if (typeof parsed !== 'object' || Array.isArray(parsed)) { + throw new Error('Headers must be a JSON object'); + } + return parsed as Record; + } catch (error) { + throw new Error( + `Invalid JSON for headers: ${(error as Error).message}`, + ); + } + }, + }, headless: { type: 'boolean', description: 'Whether to run in headless (no UI) mode.', diff --git a/src/main.ts b/src/main.ts index 84bb6d9b5..28649d629 100644 --- a/src/main.ts +++ b/src/main.ts @@ -85,9 +85,10 @@ async function getContext(): Promise { if (context?.browser !== browser) { context = await McpContext.from(browser, logger, { - experimentalDevToolsDebugging: devtools, - experimentalIncludeAllPages: args.experimentalIncludeAllPages, - }); + experimentalDevToolsDebugging: devtools, + experimentalIncludeAllPages: args.experimentalIncludeAllPages, + headers: args.headers, + }); } return context; } diff --git a/tests/McpContext.test.ts b/tests/McpContext.test.ts index e1c34f52c..78cc4c8dd 100644 --- a/tests/McpContext.test.ts +++ b/tests/McpContext.test.ts @@ -102,3 +102,22 @@ describe('McpContext', () => { ); }); }); + +describe('McpContext headers functionality', () => { + it('works with headers in context options', async () => { + await withMcpContext(async (_response, context) => { + const page = context.getSelectedPage(); + await page.setContent('Test page'); + + // Verify context was created successfully + assert.ok(context); + + // Test that we can make a request (headers should be applied if any) + const navigationPromise = page.goto('data:text/html,Test'); + await navigationPromise; + + // If we reach here without errors, headers functionality is working + assert.ok(true); + }, { debug: false }); + }); +}); diff --git a/tests/PageCollector.test.ts b/tests/PageCollector.test.ts index 41e769c47..c61f6f257 100644 --- a/tests/PageCollector.test.ts +++ b/tests/PageCollector.test.ts @@ -284,6 +284,41 @@ describe('NetworkCollector', () => { page.emit('request', request); assert.equal(collector.getData(page, true).length, 3); }); + + it('works with extra headers', async () => { + const browser = getMockBrowser(); + const page = (await browser.pages())[0]; + + let setExtraHTTPHeadersCalled = 0; + let setExtraHTTPHeadersArgs = null; + + page.setExtraHTTPHeaders = async (headers) => { + setExtraHTTPHeadersCalled++; + setExtraHTTPHeadersArgs = headers; + return Promise.resolve(); + }; + + const collector = new NetworkCollector(browser, collect => { + return { + request: req => { + collect(req); + }, + } as ListenerMap; + }, { + headers: { + "x-env": "test_mcp", + "x-user": "mock_user" + } + }); + + await collector.init([page]); + + assert.equal(setExtraHTTPHeadersCalled > 0, true, 'page.setExtraHTTPHeaders should be called'); + assert.deepEqual(setExtraHTTPHeadersArgs, { + "x-env": "test_mcp", + "x-user": "mock_user" + }, 'should set extra headers'); + }); }); describe('ConsoleCollector', () => { diff --git a/tests/cli.test.ts b/tests/cli.test.ts index 11e93a3de..c11d42ba3 100644 --- a/tests/cli.test.ts +++ b/tests/cli.test.ts @@ -222,4 +222,45 @@ describe('cli args parsing', () => { autoConnect: true, }); }); + + it('parses headers with valid JSON', async () => { + const args = parseArguments('1.0.0', [ + 'node', + 'main.js', + '--headers', + '{"x-env":"visit_from_mcp","x-mock-user":"mcp"}', + ]); + assert.deepStrictEqual(args.headers, { + 'x-env': 'visit_from_mcp', + 'x-mock-user': 'mcp', + }); + }); + + it('throws error for invalid headers JSON', async () => { + assert.throws( + () => { + parseArguments('1.0.0', [ + 'node', + 'main.js', + '--headers', + '{"invalid": json}', + ]); + }, + /Invalid JSON for headers/ + ); + }); + + it('throws error for non-object headers', async () => { + assert.throws( + () => { + parseArguments('1.0.0', [ + 'node', + 'main.js', + '--headers', + '["array", "of", "headers"]', + ]); + }, + /Headers must be a JSON object/ + ); + }); });