@@ -2,6 +2,8 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"
22import { StdioClientTransport , getDefaultEnvironment } from "@modelcontextprotocol/sdk/client/stdio.js"
33import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"
44import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"
5+ import { OAuthStreamableHTTPClientTransport } from "./OAuthStreamableHTTPClientTransport"
6+ import { OAuthConfig } from "./OAuthHandler"
57import ReconnectingEventSource from "reconnecting-eventsource"
68import {
79 CallToolResultSchema ,
@@ -117,6 +119,18 @@ const createServerTypeSchema = () => {
117119 type : z . enum ( [ "streamable-http" ] ) . optional ( ) ,
118120 url : z . string ( ) . url ( "URL must be a valid URL format" ) ,
119121 headers : z . record ( z . string ( ) ) . optional ( ) ,
122+ // OAuth configuration (optional)
123+ oauth : z
124+ . object ( {
125+ clientId : z . string ( ) ,
126+ clientSecret : z . string ( ) . optional ( ) ,
127+ authorizationUrl : z . string ( ) . url ( ) ,
128+ tokenUrl : z . string ( ) . url ( ) ,
129+ redirectUri : z . string ( ) . optional ( ) ,
130+ scopes : z . array ( z . string ( ) ) . optional ( ) ,
131+ additionalParams : z . record ( z . string ( ) ) . optional ( ) ,
132+ } )
133+ . optional ( ) ,
120134 // Ensure no stdio fields are present
121135 command : z . undefined ( ) . optional ( ) ,
122136 args : z . undefined ( ) . optional ( ) ,
@@ -736,30 +750,69 @@ export class McpHub {
736750 console . error ( `No stderr stream for ${ name } ` )
737751 }
738752 } else if ( configInjected . type === "streamable-http" ) {
739- // Streamable HTTP connection
740- transport = new StreamableHTTPClientTransport ( new URL ( configInjected . url ) , {
741- requestInit : {
753+ // Check if OAuth is configured
754+ if ( configInjected . oauth ) {
755+ // Use OAuth-enabled transport
756+ const provider = this . providerRef . deref ( )
757+ if ( ! provider ) {
758+ throw new Error ( "Provider not available for OAuth initialization" )
759+ }
760+
761+ const oauthTransport = new OAuthStreamableHTTPClientTransport ( {
762+ url : new URL ( configInjected . url ) ,
742763 headers : configInjected . headers ,
743- } ,
744- } )
764+ oauth : configInjected . oauth as OAuthConfig ,
765+ serverName : name ,
766+ context : provider . context ,
767+ } )
745768
746- // Set up Streamable HTTP specific error handling
747- transport . onerror = async ( error ) => {
748- console . error ( `Transport error for "${ name } " (streamable-http):` , error )
749- const connection = this . findConnection ( name , source )
750- if ( connection ) {
751- connection . server . status = "disconnected"
752- this . appendErrorMessage ( connection , error instanceof Error ? error . message : `${ error } ` )
769+ // Get the underlying transport
770+ transport = oauthTransport . getTransport ( )
771+
772+ // Set up error handling
773+ oauthTransport . onerror = async ( error ) => {
774+ console . error ( `Transport error for "${ name } " (streamable-http with OAuth):` , error )
775+ const connection = this . findConnection ( name , source )
776+ if ( connection ) {
777+ connection . server . status = "disconnected"
778+ this . appendErrorMessage ( connection , error instanceof Error ? error . message : `${ error } ` )
779+ }
780+ await this . notifyWebviewOfServerChanges ( )
753781 }
754- await this . notifyWebviewOfServerChanges ( )
755- }
756782
757- transport . onclose = async ( ) => {
758- const connection = this . findConnection ( name , source )
759- if ( connection ) {
760- connection . server . status = "disconnected"
783+ oauthTransport . onclose = async ( ) => {
784+ const connection = this . findConnection ( name , source )
785+ if ( connection ) {
786+ connection . server . status = "disconnected"
787+ }
788+ await this . notifyWebviewOfServerChanges ( )
789+ }
790+ } else {
791+ // Standard Streamable HTTP connection without OAuth
792+ transport = new StreamableHTTPClientTransport ( new URL ( configInjected . url ) , {
793+ requestInit : {
794+ headers : configInjected . headers ,
795+ } ,
796+ } )
797+
798+ // Set up Streamable HTTP specific error handling
799+ transport . onerror = async ( error ) => {
800+ console . error ( `Transport error for "${ name } " (streamable-http):` , error )
801+ const connection = this . findConnection ( name , source )
802+ if ( connection ) {
803+ connection . server . status = "disconnected"
804+ this . appendErrorMessage ( connection , error instanceof Error ? error . message : `${ error } ` )
805+ }
806+ await this . notifyWebviewOfServerChanges ( )
807+ }
808+
809+ transport . onclose = async ( ) => {
810+ const connection = this . findConnection ( name , source )
811+ if ( connection ) {
812+ connection . server . status = "disconnected"
813+ }
814+ await this . notifyWebviewOfServerChanges ( )
761815 }
762- await this . notifyWebviewOfServerChanges ( )
763816 }
764817 } else if ( configInjected . type === "sse" ) {
765818 // SSE connection
0 commit comments