Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 49 additions & 26 deletions src/browserServerBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
import { fileURLToPath } from 'url';
import { z } from 'zod';
import { FullConfig } from './config.js';
import { Context } from './context.js';
import { Context, ContextOptions } from './context.js';
import { logUnhandledError } from './log.js';
import { Response } from './response.js';
import { SessionLog } from './sessionLog.js';
import { filteredTools } from './tools.js';
import { packageJSON } from './package.js';
import { defineTool } from './tools/tool.js';

import * as mcpServer from './mcp/server.js';
import type { Tool } from './tools/tool.js';
import type { BrowserContextFactory } from './browserContextFactory.js';
import type * as mcpServer from './mcp/server.js';
import type { ServerBackend } from './mcp/server.js';

type NonEmptyArray<T> = [T, ...T[]];
Expand All @@ -43,33 +43,28 @@ export class BrowserServerBackend implements ServerBackend {
private _sessionLog: SessionLog | undefined;
private _config: FullConfig;
private _browserContextFactory: BrowserContextFactory;
private _browserContextFactories: FactoryList;

onChangeProxyTarget: ServerBackend['onChangeProxyTarget'];

constructor(config: FullConfig, factories: FactoryList) {
this._config = config;
this._browserContextFactory = factories[0];
this._browserContextFactories = factories;
this._tools = filteredTools(config);
if (factories.length > 1)
this._tools.push(this._defineContextSwitchTool(factories));
}

async initialize(server: mcpServer.Server): Promise<void> {
const capabilities = server.getClientCapabilities() as mcpServer.ClientCapabilities;
let rootPath: string | undefined;
if (capabilities.roots && (
server.getClientVersion()?.name === 'Visual Studio Code' ||
server.getClientVersion()?.name === 'Visual Studio Code - Insiders')) {
const { roots } = await server.listRoots();
const firstRootUri = roots[0]?.uri;
const url = firstRootUri ? new URL(firstRootUri) : undefined;
rootPath = url ? fileURLToPath(url) : undefined;
}
async initialize(info: mcpServer.InitializeInfo): Promise<void> {
this._defineContextSwitchTool(mcpServer.isVSCode(info.clientVersion));

const rootPath = info.roots?.roots[0]?.uri ? fileURLToPath(new URL(info.roots.roots[0].uri)) : undefined;
this._sessionLog = this._config.saveSession ? await SessionLog.create(this._config, rootPath) : undefined;
this._context = new Context({
tools: this._tools,
config: this._config,
browserContextFactory: this._browserContextFactory,
sessionLog: this._sessionLog,
clientInfo: { ...server.getClientVersion(), rootPath },
clientInfo: { ...info.clientVersion, rootPath },
});
}

Expand Down Expand Up @@ -98,39 +93,67 @@ export class BrowserServerBackend implements ServerBackend {
void this._context!.dispose().catch(logUnhandledError);
}

private _defineContextSwitchTool(factories: FactoryList): Tool<any> {
const self = this;
return defineTool({
private _defineContextSwitchTool(isVSCode: boolean) {
const contextSwitchers: { name: string, description?: string, switch(options: any): Promise<void> }[] = [];
for (const factory of this._browserContextFactories) {
contextSwitchers.push({
name: factory.name,
description: factory.description,
switch: async () => {
await this._setContextFactory(factory);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that we now mix them and have 2 ways to switch how we connect to the browser. Maybe migrate the context factories to the backend switchers as well or what is the plan?

}
});
}

const askForOptions = isVSCode;
if (isVSCode) {
contextSwitchers.push({
name: 'vscode',
switch: async (options: any) => {
if (!options.connectionString || !options.lib)
this.onChangeProxyTarget?.('', {});
else
this.onChangeProxyTarget?.('vscode', options);
}
});
}

if (contextSwitchers.length < 2)
return;

this._tools.push(defineTool<any>({
capability: 'core',

schema: {
name: 'browser_connect',
title: 'Connect to a browser context',
description: [
'Connect to a browser using one of the available methods:',
...factories.map(factory => `- "${factory.name}": ${factory.description}`),
...contextSwitchers.filter(c => c.description).map(c => `- "${c.name}": ${c.description}`),
`By default, you're connected to the first method. Only call this tool to change it.`,
].join('\n'),
inputSchema: z.object({
method: z.enum(factories.map(factory => factory.name) as [string, ...string[]]).default(factories[0].name).describe('The method to use to connect to the browser'),
method: z.enum(contextSwitchers.map(c => c.name) as [string, ...string[]]).describe('The method to use to connect to the browser'),
options: askForOptions ? z.any().optional().describe('options for the connection method') : z.void(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not define this parameter if askForOptions is false

}),
type: 'readOnly',
},

async handle(context, params, response) {
const factory = factories.find(factory => factory.name === params.method);
if (!factory) {
const contextSwitcher = contextSwitchers.find(c => c.name === params.method);
if (!contextSwitcher) {
response.addError('Unknown connection method: ' + params.method);
return;
}
await self._setContextFactory(factory);
await contextSwitcher.switch(params.options);
response.addResult('Successfully changed connection method.');
}
});
}));
}

private async _setContextFactory(newFactory: BrowserContextFactory) {
if (this._context) {
const options = {
const options: ContextOptions = {
...this._context.options,
browserContextFactory: newFactory,
};
Expand Down
6 changes: 5 additions & 1 deletion src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import type { SessionLog } from './sessionLog.js';

const testDebug = debug('pw:mcp:test');

type ContextOptions = {
export type ContextOptions = {
tools: Tool[];
config: FullConfig;
browserContextFactory: BrowserContextFactory;
Expand Down Expand Up @@ -210,6 +210,10 @@ export class Context {
for (const page of browserContext.pages())
this._onPageCreated(page);
browserContext.on('page', page => this._onPageCreated(page));
browserContext.on('close', () => {
this._browserContextPromise = undefined;
this._closeBrowserContextPromise = undefined;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should resolve this promise in case there are waiting close calls?

});
if (this.config.saveTrace) {
await browserContext.tracing.start({
name: 'trace',
Expand Down
96 changes: 89 additions & 7 deletions src/mcp/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
import { ManualPromise } from '../manualPromise.js';
import { logUnhandledError } from '../log.js';

import type { ImageContent, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { ImageContent, Implementation, ListRootsResult, TextContent } from '@modelcontextprotocol/sdk/types.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
export type { Server } from '@modelcontextprotocol/sdk/server/index.js';

Expand All @@ -46,17 +46,80 @@ export type ToolSchema<Input extends z.Schema> = {

export type ToolHandler = (toolName: string, params: any) => Promise<ToolResponse>;

export interface InitializeInfo {
clientVersion: Implementation;
roots?: ListRootsResult;
}

export interface ServerBackend {
name: string;
version: string;
initialize?(server: Server): Promise<void>;
initialize?(info: InitializeInfo): Promise<void>;
tools(): ToolSchema<any>[];
callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse>;
serverClosed?(): void;

onChangeProxyTarget?: (target: string, options: any) => void;
}

export type ServerBackendFactory = () => ServerBackend;

export class ServerBackendSwitcher implements ServerBackend {
private _target: ServerBackend;
private _initializeInfo?: InitializeInfo;

constructor(private readonly _targetFactories: Record<string, (options: any) => ServerBackend>) {
const defaultTargetFactory = this._targetFactories[''];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use "default" key for the default factory?

this._target = defaultTargetFactory({});
this._target.onChangeProxyTarget = this._handleChangeProxyTarget;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this._handleChangeProxyTarget.bind(this) ?

}

_handleChangeProxyTarget = (name: string, options: any) => {
const factory = this._targetFactories[name];
if (!factory)
throw new Error(`Unknown target: ${name}`);
const target = factory(options);
this.switch(target).catch(logUnhandledError);
};

async switch(target: ServerBackend) {
const old = this._target;
old.onChangeProxyTarget = undefined;
this._target = target;
this._target.onChangeProxyTarget = this._handleChangeProxyTarget;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to assign onChangeProxyTarget to the target while we could just handle that particular tool call in this class ant spare all other backends from the knowledge about onChangeProxyTarget? Basically similar to how we switch between context factories but just intercept the backend switch call here.

if (this._initializeInfo) {
old.serverClosed?.();
await this.initialize(this._initializeInfo);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
await this.initialize(this._initializeInfo);
await this._target.initialize?.(this._initializeInfo);

}
}

get name() {
return this._target.name;
}

get version() {
return this._target.version;
}

async initialize(info: InitializeInfo): Promise<void> {
this._initializeInfo = info;
await this._target.initialize?.(info);
}

tools(): ToolSchema<any>[] {
return this._target.tools();
}

async callTool(schema: ToolSchema<any>, parsedArguments: any): Promise<ToolResponse> {
return this._target.callTool(schema, parsedArguments);
}

serverClosed(): void {
this._target.serverClosed?.();
this._initializeInfo = undefined;
}
}

export async function connect(serverBackendFactory: ServerBackendFactory, transport: Transport, runHeartbeat: boolean) {
const backend = serverBackendFactory();
const server = createServer(backend, runHeartbeat);
Expand All @@ -71,9 +134,9 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
}
});

const tools = backend.tools();
server.setRequestHandler(ListToolsRequestSchema, async () => {
return { tools: tools.map(tool => ({
await initializedPromise;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this deadlock with some clients?

return { tools: backend.tools().map(tool => ({
name: tool.name,
description: tool.description,
inputSchema: zodToJsonSchema(tool.inputSchema),
Expand All @@ -99,7 +162,7 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
content: [{ type: 'text', text: '### Result\n' + messages.join('\n') }],
isError: true,
});
const tool = tools.find(tool => tool.name === request.params.name) as ToolSchema<any>;
const tool = backend.tools().find(tool => tool.name === request.params.name) as ToolSchema<any>;
if (!tool)
return errorResult(`Error: Tool "${request.params.name}" not found`);

Expand All @@ -109,13 +172,32 @@ export function createServer(backend: ServerBackend, runHeartbeat: boolean): Ser
return errorResult(String(error));
}
});
addServerListener(server, 'initialized', () => {
backend.initialize?.(server).then(() => initializedPromise.resolve()).catch(logUnhandledError);
addServerListener(server, 'initialized', async () => {
try {
const info = await getInitializeInfo(server);
await backend.initialize?.(info);
initializedPromise.resolve();
} catch (e) {
logUnhandledError(e);
}
});
addServerListener(server, 'close', () => backend.serverClosed?.());
return server;
}

async function getInitializeInfo(server: Server) {
const info: InitializeInfo = {
clientVersion: server.getClientVersion()!,
};
if (server.getClientCapabilities()?.roots?.listRoots && isVSCode(info.clientVersion))
info.roots = await server.listRoots();
return info;
}

export function isVSCode(clientVersion: Implementation): boolean {
return clientVersion.name === 'Visual Studio Code' || clientVersion.name === 'Visual Studio Code - Insiders';
}

const startHeartbeat = (server: Server) => {
const beat = () => {
Promise.race([
Expand Down
7 changes: 6 additions & 1 deletion src/program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import { BrowserServerBackend, FactoryList } from './browserServerBackend.js';
import { Context } from './context.js';
import { contextFactory } from './browserContextFactory.js';
import { runLoopTools } from './loopTools/main.js';
import { ServerBackendSwitcher } from './mcp/server.js';
import { VSCodeServerBackend } from './vscode/vscodeHost.js';

program
.version('Version ' + packageJSON.version)
Expand Down Expand Up @@ -82,7 +84,10 @@ program
const factories: FactoryList = [browserContextFactory];
if (options.connectTool)
factories.push(createExtensionContextFactory(config));
const serverBackendFactory = () => new BrowserServerBackend(config, factories);
const serverBackendFactory = () => new ServerBackendSwitcher({
'': () => new BrowserServerBackend(config, factories),
'vscode': options => new VSCodeServerBackend(config, options.connectionString, options.lib),
});
await mcpTransport.start(serverBackendFactory, config.server);

if (config.saveTrace) {
Expand Down
2 changes: 1 addition & 1 deletion src/tab.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export class Tab extends EventEmitter<TabEventsInterface> {
});
page.on('dialog', dialog => this._dialogShown(dialog));
page.on('download', download => {
void this._downloadStarted(download);
void this._downloadStarted(download).catch(logUnhandledError);
});
page.setDefaultNavigationTimeout(60000);
page.setDefaultTimeout(5000);
Expand Down
Loading
Loading