diff --git a/src/browserServerBackend.ts b/src/browserServerBackend.ts index 3a5ada4c8..a60449897 100644 --- a/src/browserServerBackend.ts +++ b/src/browserServerBackend.ts @@ -17,7 +17,7 @@ import { fileURLToPath } from 'url'; import { z } from 'zod'; import { FullConfig } from './config.js'; -import { Context } from './context.js'; +import { Context, ContextOptions } from './context.js'; import { logUnhandledError } from './log.js'; import { Response } from './response.js'; import { SessionLog } from './sessionLog.js'; @@ -25,9 +25,9 @@ import { filteredTools } from './tools.js'; import { packageJSON } from './package.js'; import { defineTool } from './tools/tool.js'; +import * as mcpServer from './mcp/server.js'; import type { Tool } from './tools/tool.js'; import type { BrowserContextFactory } from './browserContextFactory.js'; -import type * as mcpServer from './mcp/server.js'; import type { ServerBackend } from './mcp/server.js'; type NonEmptyArray = [T, ...T[]]; @@ -43,33 +43,28 @@ export class BrowserServerBackend implements ServerBackend { private _sessionLog: SessionLog | undefined; private _config: FullConfig; private _browserContextFactory: BrowserContextFactory; + private _browserContextFactories: FactoryList; + + onChangeProxyTarget: ServerBackend['onChangeProxyTarget']; constructor(config: FullConfig, factories: FactoryList) { this._config = config; this._browserContextFactory = factories[0]; + this._browserContextFactories = factories; this._tools = filteredTools(config); - if (factories.length > 1) - this._tools.push(this._defineContextSwitchTool(factories)); } - async initialize(server: mcpServer.Server): Promise { - const capabilities = server.getClientCapabilities() as mcpServer.ClientCapabilities; - let rootPath: string | undefined; - if (capabilities.roots && ( - server.getClientVersion()?.name === 'Visual Studio Code' || - server.getClientVersion()?.name === 'Visual Studio Code - Insiders')) { - const { roots } = await server.listRoots(); - const firstRootUri = roots[0]?.uri; - const url = firstRootUri ? new URL(firstRootUri) : undefined; - rootPath = url ? fileURLToPath(url) : undefined; - } + async initialize(info: mcpServer.InitializeInfo): Promise { + this._defineContextSwitchTool(mcpServer.isVSCode(info.clientVersion)); + + const rootPath = info.roots?.roots[0]?.uri ? fileURLToPath(new URL(info.roots.roots[0].uri)) : undefined; this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config, rootPath) : undefined; this._context = new Context({ tools: this._tools, config: this._config, browserContextFactory: this._browserContextFactory, sessionLog: this._sessionLog, - clientInfo: { ...server.getClientVersion(), rootPath }, + clientInfo: { ...info.clientVersion, rootPath }, }); } @@ -98,9 +93,35 @@ export class BrowserServerBackend implements ServerBackend { void this._context!.dispose().catch(logUnhandledError); } - private _defineContextSwitchTool(factories: FactoryList): Tool { - const self = this; - return defineTool({ + private _defineContextSwitchTool(isVSCode: boolean) { + const contextSwitchers: { name: string, description?: string, switch(options: any): Promise }[] = []; + for (const factory of this._browserContextFactories) { + contextSwitchers.push({ + name: factory.name, + description: factory.description, + switch: async () => { + await this._setContextFactory(factory); + } + }); + } + + const askForOptions = isVSCode; + if (isVSCode) { + contextSwitchers.push({ + name: 'vscode', + switch: async (options: any) => { + if (!options.connectionString || !options.lib) + this.onChangeProxyTarget?.('', {}); + else + this.onChangeProxyTarget?.('vscode', options); + } + }); + } + + if (contextSwitchers.length < 2) + return; + + this._tools.push(defineTool({ capability: 'core', schema: { @@ -108,29 +129,31 @@ export class BrowserServerBackend implements ServerBackend { title: 'Connect to a browser context', description: [ 'Connect to a browser using one of the available methods:', - ...factories.map(factory => `- "${factory.name}": ${factory.description}`), + ...contextSwitchers.filter(c => c.description).map(c => `- "${c.name}": ${c.description}`), + `By default, you're connected to the first method. Only call this tool to change it.`, ].join('\n'), inputSchema: z.object({ - method: z.enum(factories.map(factory => factory.name) as [string, ...string[]]).default(factories[0].name).describe('The method to use to connect to the browser'), + method: z.enum(contextSwitchers.map(c => c.name) as [string, ...string[]]).describe('The method to use to connect to the browser'), + options: askForOptions ? z.any().optional().describe('options for the connection method') : z.void(), }), type: 'readOnly', }, async handle(context, params, response) { - const factory = factories.find(factory => factory.name === params.method); - if (!factory) { + const contextSwitcher = contextSwitchers.find(c => c.name === params.method); + if (!contextSwitcher) { response.addError('Unknown connection method: ' + params.method); return; } - await self._setContextFactory(factory); + await contextSwitcher.switch(params.options); response.addResult('Successfully changed connection method.'); } - }); + })); } private async _setContextFactory(newFactory: BrowserContextFactory) { if (this._context) { - const options = { + const options: ContextOptions = { ...this._context.options, browserContextFactory: newFactory, }; diff --git a/src/context.ts b/src/context.ts index e84356d7e..18e531658 100644 --- a/src/context.ts +++ b/src/context.ts @@ -29,7 +29,7 @@ import type { SessionLog } from './sessionLog.js'; const testDebug = debug('pw:mcp:test'); -type ContextOptions = { +export type ContextOptions = { tools: Tool[]; config: FullConfig; browserContextFactory: BrowserContextFactory; @@ -210,6 +210,10 @@ export class Context { for (const page of browserContext.pages()) this._onPageCreated(page); browserContext.on('page', page => this._onPageCreated(page)); + browserContext.on('close', () => { + this._browserContextPromise = undefined; + this._closeBrowserContextPromise = undefined; + }); if (this.config.saveTrace) { await browserContext.tracing.start({ name: 'trace', diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 18c31447d..1aa8365f4 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -21,7 +21,7 @@ import { zodToJsonSchema } from 'zod-to-json-schema'; import { ManualPromise } from '../manualPromise.js'; import { logUnhandledError } from '../log.js'; -import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js'; +import type { ImageContent, Implementation, ListRootsResult, TextContent } from '@modelcontextprotocol/sdk/types.js'; import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; export type { Server } from '@modelcontextprotocol/sdk/server/index.js'; @@ -46,17 +46,80 @@ export type ToolSchema = { export type ToolHandler = (toolName: string, params: any) => Promise; +export interface InitializeInfo { + clientVersion: Implementation; + roots?: ListRootsResult; +} + export interface ServerBackend { name: string; version: string; - initialize?(server: Server): Promise; + initialize?(info: InitializeInfo): Promise; tools(): ToolSchema[]; callTool(schema: ToolSchema, parsedArguments: any): Promise; serverClosed?(): void; + + onChangeProxyTarget?: (target: string, options: any) => void; } export type ServerBackendFactory = () => ServerBackend; +export class ServerBackendSwitcher implements ServerBackend { + private _target: ServerBackend; + private _initializeInfo?: InitializeInfo; + + constructor(private readonly _targetFactories: Record ServerBackend>) { + const defaultTargetFactory = this._targetFactories['']; + this._target = defaultTargetFactory({}); + this._target.onChangeProxyTarget = this._handleChangeProxyTarget; + } + + _handleChangeProxyTarget = (name: string, options: any) => { + const factory = this._targetFactories[name]; + if (!factory) + throw new Error(`Unknown target: ${name}`); + const target = factory(options); + this.switch(target).catch(logUnhandledError); + }; + + async switch(target: ServerBackend) { + const old = this._target; + old.onChangeProxyTarget = undefined; + this._target = target; + this._target.onChangeProxyTarget = this._handleChangeProxyTarget; + if (this._initializeInfo) { + old.serverClosed?.(); + await this.initialize(this._initializeInfo); + } + } + + get name() { + return this._target.name; + } + + get version() { + return this._target.version; + } + + async initialize(info: InitializeInfo): Promise { + this._initializeInfo = info; + await this._target.initialize?.(info); + } + + tools(): ToolSchema[] { + return this._target.tools(); + } + + async callTool(schema: ToolSchema, parsedArguments: any): Promise { + return this._target.callTool(schema, parsedArguments); + } + + serverClosed(): void { + this._target.serverClosed?.(); + this._initializeInfo = undefined; + } +} + export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) { const backend = serverBackendFactory(); const server = createServer(backend, runHeartbeat); @@ -71,9 +134,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser } }); - const tools = backend.tools(); server.setRequestHandler(ListToolsRequestSchema, async () => { - return { tools: tools.map(tool => ({ + await initializedPromise; + return { tools: backend.tools().map(tool => ({ name: tool.name, description: tool.description, inputSchema: zodToJsonSchema(tool.inputSchema), @@ -99,7 +162,7 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser content: [{ type: 'text', text: '### Result\n' + messages.join('\n') }], isError: true, }); - const tool = tools.find(tool => tool.name === request.params.name) as ToolSchema; + const tool = backend.tools().find(tool => tool.name === request.params.name) as ToolSchema; if (!tool) return errorResult(`Error: Tool "${request.params.name}" not found`); @@ -109,13 +172,32 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser return errorResult(String(error)); } }); - addServerListener(server, 'initialized', () => { - backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError); + addServerListener(server, 'initialized', async () => { + try { + const info = await getInitializeInfo(server); + await backend.initialize?.(info); + initializedPromise.resolve(); + } catch (e) { + logUnhandledError(e); + } }); addServerListener(server, 'close', () => backend.serverClosed?.()); return server; } +async function getInitializeInfo(server: Server) { + const info: InitializeInfo = { + clientVersion: server.getClientVersion()!, + }; + if (server.getClientCapabilities()?.roots?.listRoots && isVSCode(info.clientVersion)) + info.roots = await server.listRoots(); + return info; +} + +export function isVSCode(clientVersion: Implementation): boolean { + return clientVersion.name === 'Visual Studio Code' || clientVersion.name === 'Visual Studio Code - Insiders'; +} + const startHeartbeat = (server: Server) => { const beat = () => { Promise.race([ diff --git a/src/program.ts b/src/program.ts index ae1f1d934..223fb1b50 100644 --- a/src/program.ts +++ b/src/program.ts @@ -26,6 +26,8 @@ import { BrowserServerBackend, FactoryList } from './browserServerBackend.js'; import { Context } from './context.js'; import { contextFactory } from './browserContextFactory.js'; import { runLoopTools } from './loopTools/main.js'; +import { ServerBackendSwitcher } from './mcp/server.js'; +import { VSCodeServerBackend } from './vscode/vscodeHost.js'; program .version('Version ' + packageJSON.version) @@ -82,7 +84,10 @@ program const factories: FactoryList = [browserContextFactory]; if (options.connectTool) factories.push(createExtensionContextFactory(config)); - const serverBackendFactory = () => new BrowserServerBackend(config, factories); + const serverBackendFactory = () => new ServerBackendSwitcher({ + '': () => new BrowserServerBackend(config, factories), + 'vscode': options => new VSCodeServerBackend(config, options.connectionString, options.lib), + }); await mcpTransport.start(serverBackendFactory, config.server); if (config.saveTrace) { diff --git a/src/tab.ts b/src/tab.ts index d7f44fa51..674b2c42f 100644 --- a/src/tab.ts +++ b/src/tab.ts @@ -74,7 +74,7 @@ export class Tab extends EventEmitter { }); page.on('dialog', dialog => this._dialogShown(dialog)); page.on('download', download => { - void this._downloadStarted(download); + void this._downloadStarted(download).catch(logUnhandledError); }); page.setDefaultNavigationTimeout(60000); page.setDefaultTimeout(5000); diff --git a/src/vscode/process.ts b/src/vscode/process.ts new file mode 100644 index 000000000..973558b50 --- /dev/null +++ b/src/vscode/process.ts @@ -0,0 +1,128 @@ +/** + * Copyright Microsoft Corporation. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { createRequire } from 'node:module'; + +const require = createRequire(import.meta.url); + +interface SerializedError { + name: string; + message: string; + stack?: string; +} + +function serializeError(error: Error | any): SerializedError { + if (error instanceof Error) { + return { + name: error.name, + message: error.message, + stack: error.stack, + }; + } + + return { + name: 'Error', + message: '' + error + }; +} + +export type ProtocolRequest = { + id: number; + method: string; + params?: any; +}; + +export type ProtocolResponse = { + id?: number; + error?: SerializedError; + method?: string; + params?: any; + result?: any; +}; + +export class ProcessRunner { + async gracefullyClose(): Promise { } + + protected dispatchEvent(method: string, params: any) { + const response: ProtocolResponse = { method, params }; + sendMessageToParent({ method: '__dispatch__', params: response }); + } +} + +let gracefullyCloseCalled = false; +let forceExitInitiated = false; + +sendMessageToParent({ method: 'ready' }); + +process.on('disconnect', () => gracefullyCloseAndExit(true)); +process.on('SIGINT', () => {}); +process.on('SIGTERM', () => {}); + +let processRunner: ProcessRunner | undefined; + +process.on('message', async (message: any) => { + if (message.method === '__init__') { + const { runnerParams, runnerScript } = message.params as { runnerParams: any, runnerScript: string }; + const { create } = require(runnerScript); + processRunner = create(runnerParams) as ProcessRunner; + return; + } + if (message.method === '__stop__') { + await gracefullyCloseAndExit(false); + return; + } + if (message.method === '__dispatch__') { + const { id, method, params } = message.params as ProtocolRequest; + try { + const result = await (processRunner as any)[method](params); + const response: ProtocolResponse = { id, result }; + sendMessageToParent({ method: '__dispatch__', params: response }); + } catch (e) { + const response: ProtocolResponse = { id, error: serializeError(e) }; + sendMessageToParent({ method: '__dispatch__', params: response }); + } + } +}); + +const kForceExitTimeout = 30000; + +async function gracefullyCloseAndExit(forceExit: boolean) { + if (forceExit && !forceExitInitiated) { + forceExitInitiated = true; + // Force exit after 30 seconds. + setTimeout(() => process.exit(0), kForceExitTimeout); + } + if (!gracefullyCloseCalled) { + gracefullyCloseCalled = true; + // Meanwhile, try to gracefully shutdown. + await processRunner?.gracefullyClose().catch(() => {}); + process.exit(0); + } +} + +function sendMessageToParent(message: { method: string, params?: any }) { + try { + process.send!(message); + } catch (e) { + try { + // By default, the IPC messages are serialized as JSON. + JSON.stringify(message); + } catch { + // Always throw serialization errors. + throw e; + } + // Can throw when closing. + } +} diff --git a/src/vscode/processHost.ts b/src/vscode/processHost.ts new file mode 100644 index 000000000..4471d4bc6 --- /dev/null +++ b/src/vscode/processHost.ts @@ -0,0 +1,143 @@ +/** + * Copyright Microsoft Corporation. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import child_process from 'child_process'; +import { assert } from 'console'; +import { EventEmitter } from 'events'; +import { ProtocolResponse } from './process.js'; + +export type ProcessExitData = { + unexpectedly: boolean; + code: number | null; + signal: NodeJS.Signals | null; +}; + +export class ProcessHost extends EventEmitter { + private process: child_process.ChildProcess | undefined; + private _didSendStop = false; + private _processDidExit = false; + private _didExitAndRanOnExit = false; + private _runnerScript: string; + private _lastMessageId = 0; + private _callbacks = new Map void, reject: (error: Error) => void }>(); + private _extraEnv: Record; + + constructor(runnerScript: string, env: Record) { + super(); + this._runnerScript = runnerScript; + this._extraEnv = env; + } + + async startRunner(runnerParams: any, options: { onStdOut?: (chunk: Buffer | string) => void, onStdErr?: (chunk: Buffer | string) => void } = {}): Promise { + assert(!this.process, 'Internal error: starting the same process twice'); + this.process = child_process.fork(new URL('./process.js', import.meta.url), { + detached: false, + env: { + ...process.env, + ...this._extraEnv, + }, + stdio: [ + 'ignore', + options.onStdOut ? 'pipe' : 'inherit', + options.onStdErr ? 'pipe' : 'inherit', + 'ipc', + ], + }); + this.process.on('exit', async (code, signal) => { + this._processDidExit = true; + await this.onExit(); + this._didExitAndRanOnExit = true; + this.emit('exit', { unexpectedly: !this._didSendStop, code, signal } as ProcessExitData); + }); + this.process.on('error', e => {}); // do not yell at a send to dead process. + this.process.on('message', (message: any) => { + if (message.method === '__dispatch__') { + const { id, error, method, params, result } = message.params as ProtocolResponse; + if (id && this._callbacks.has(id)) { + const { resolve, reject } = this._callbacks.get(id)!; + this._callbacks.delete(id); + if (error) { + const errorObject = new Error(error.message); + errorObject.stack = error.stack; + reject(errorObject); + } else { + resolve(result); + } + } else { + this.emit(method!, params); + } + } else { + this.emit(message.method!, message.params); + } + }); + + if (options.onStdOut) + this.process.stdout?.on('data', options.onStdOut); + if (options.onStdErr) + this.process.stderr?.on('data', options.onStdErr); + + const error = await new Promise(resolve => { + this.process!.once('exit', (code, signal) => resolve({ unexpectedly: true, code, signal })); + this.once('ready', () => resolve(undefined)); + }); + + if (error) + return error; + + this.send({ + method: '__init__', + params: { + runnerScript: this._runnerScript, + runnerParams + } + }); + } + + sendMessage(message: { method: string, params?: any }) { + const id = ++this._lastMessageId; + this.send({ + method: '__dispatch__', + params: { id, ...message } + }); + return new Promise((resolve, reject) => { + this._callbacks.set(id, { resolve, reject }); + }); + } + + protected sendMessageNoReply(message: { method: string, params?: any }) { + this.sendMessage(message).catch(() => {}); + } + + protected async onExit() { + } + + async stop() { + if (!this._processDidExit && !this._didSendStop) { + this.send({ method: '__stop__' }); + this._didSendStop = true; + } + if (!this._didExitAndRanOnExit) + await new Promise(f => this.once('exit', f)); + } + + didSendStop() { + return this._didSendStop; + } + + private send(message: { method: string, params?: any }) { + this.process?.send(message); + } +} diff --git a/src/vscode/vscodeHost.ts b/src/vscode/vscodeHost.ts new file mode 100644 index 000000000..cdc76cb96 --- /dev/null +++ b/src/vscode/vscodeHost.ts @@ -0,0 +1,68 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { FullConfig } from '../config.js'; +import { InitializeInfo, ServerBackend, ToolSchema } from '../mcp/server.js'; +import { packageJSON } from '../package.js'; +import { filteredTools } from '../tools.js'; +import { ProcessHost } from './processHost.js'; +import { VSCodeInitParams } from './vscodeMain.js'; + +export class VSCodeServerBackend extends ProcessHost implements ServerBackend { + readonly name = 'Playwright'; + readonly version = packageJSON.version; + + onChangeProxyTarget?: (target: string, options: any) => void; + + constructor(private _config: FullConfig, private _connectionString: string, private _lib: string) { + super(new URL('./vscodeMain.js', import.meta.url).pathname, {}); + this.on('changeProxyTarget', ({ target, options }) => this.onChangeProxyTarget?.(target, options)); + } + + async initialize(info: InitializeInfo) { + const params: VSCodeInitParams = { + config: this._config, + connectionString: this._connectionString, + lib: this._lib, + }; + const error = await this.startRunner(params); + if (error) + throw error; + + await this.sendMessage({ + method: 'initialize', + params: info + }); + } + + tools(): ToolSchema[] { + return filteredTools(this._config).map(tool => tool.schema); + } + + serverClosed?() { + void this.stop(); + } + + async callTool(schema: ToolSchema, parsedArguments: any) { + const response = await this.sendMessage({ + method: 'callTool', + params: { + toolName: schema.name, + parsedArguments, + }, + }); + return response as any; + } +} diff --git a/src/vscode/vscodeMain.ts b/src/vscode/vscodeMain.ts new file mode 100644 index 000000000..756e674f6 --- /dev/null +++ b/src/vscode/vscodeMain.ts @@ -0,0 +1,101 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BrowserServerBackend } from '../browserServerBackend.js'; +import { FullConfig } from '../config.js'; +import { logUnhandledError } from '../log.js'; +import { ProcessRunner } from './process.js'; +import { InitializeInfo } from '../mcp/server.js'; +import type { BrowserContextFactory, ClientInfo } from '../browserContextFactory.js'; +import type * as playwright from 'playwright'; + +export interface VSCodeInitParams { + config: FullConfig; + connectionString: string; + lib: string; +} + +export class VSCodeMain extends ProcessRunner { + private _backend: BrowserServerBackend; + constructor(params: VSCodeInitParams) { + super(); + const factory = new VSCodeContextFactory(params.config, params.connectionString, params.lib); + this._backend = new BrowserServerBackend(params.config, [factory]); + this._backend.onChangeProxyTarget = (target, options) => this.dispatchEvent('changeProxyTarget', { target, options }); + } + + async gracefullyClose() { + this._backend.serverClosed?.(); + } + + async callTool({ toolName, parsedArguments }: { toolName: string; parsedArguments: any; }): Promise { + const tool = this._backend.tools().find(tool => tool.name === toolName); + return await this._backend.callTool(tool!, parsedArguments); + } + + async initialize(params: InitializeInfo) { + await this._backend.initialize(params); + } +} + +export const create = (params: VSCodeInitParams) => new VSCodeMain(params); + +/** + * turns the operating system's "Close" button into UI for closing the browser. + * the user can use it to dismiss the foreground browser window, and the browser will be closed. + */ +function closeOnUIClose(context: playwright.BrowserContext) { + context.on('close', () => context.browser()?.close({ reason: 'ui closed' }).catch(logUnhandledError)); + context.on('page', page => { + page.on('close', () => { + if (context.pages().length === 0) + void context.close().catch(logUnhandledError); + }); + }); +} + +class VSCodeContextFactory implements BrowserContextFactory { + name = 'vscode'; + description = 'Connect to a browser running in the Playwright VS Code extension'; + + constructor(private readonly _config: FullConfig, private readonly _connectionString: string, private readonly _lib: string) {} + + async createContext(clientInfo: ClientInfo, abortSignal: AbortSignal): Promise<{ browserContext: playwright.BrowserContext; close: () => Promise; }> { + // TODO: what's the difference between the abortSignal and the close() retval? + + const connectionString = new URL(this._connectionString); + connectionString.searchParams.set('launch-options', JSON.stringify({ + ...this._config.browser.launchOptions, + ...this._config.browser.contextOptions, + userDataDir: this._config.browser.userDataDir, + })); + const lib = await import(this._lib).then(mod => mod.default ?? mod) as typeof import('playwright'); + const browser = await lib[this._config.browser.browserName].connect(connectionString.toString()); + + const context: playwright.BrowserContext = browser.contexts()[0] ?? await browser.newContext(this._config.browser.contextOptions); + + // when the user closes the browser window, we should reconnect. + closeOnUIClose(context); + + return { + browserContext: context, + close: async () => { + // close the connection. in this mode, the browser will survive + await browser.close(); + } + }; + } +}