@@ -16,6 +16,7 @@ import io.ktor.http.protocolWithAuthority
1616import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
1717import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
1818import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
19+ import kotlinx.coroutines.CancellationException
1920import kotlinx.coroutines.CompletableDeferred
2021import kotlinx.coroutines.CoroutineName
2122import kotlinx.coroutines.CoroutineScope
@@ -24,10 +25,11 @@ import kotlinx.coroutines.Job
2425import kotlinx.coroutines.SupervisorJob
2526import kotlinx.coroutines.cancel
2627import kotlinx.coroutines.cancelAndJoin
28+ import kotlinx.coroutines.ensureActive
2729import kotlinx.coroutines.launch
30+ import kotlinx.serialization.SerializationException
2831import kotlin.concurrent.atomics.AtomicBoolean
2932import kotlin.concurrent.atomics.ExperimentalAtomicApi
30- import kotlin.properties.Delegates
3133import kotlin.time.Duration
3234
3335@Deprecated(" Use SseClientTransport instead" , ReplaceWith (" SseClientTransport" ), DeprecationLevel .WARNING )
@@ -44,97 +46,59 @@ public class SseClientTransport(
4446 private val reconnectionTime : Duration ? = null ,
4547 private val requestBuilder : HttpRequestBuilder .() -> Unit = {},
4648) : AbstractTransport() {
47- private val scope by lazy {
48- CoroutineScope (session.coroutineContext + SupervisorJob ())
49- }
50-
5149 private val initialized: AtomicBoolean = AtomicBoolean (false )
52- private var session: ClientSSESession by Delegates .notNull()
5350 private val endpoint = CompletableDeferred <String >()
5451
52+ private lateinit var session: ClientSSESession
53+ private lateinit var scope: CoroutineScope
5554 private var job: Job ? = null
5655
57- private val baseUrl by lazy {
58- val requestUrl = session.call.request.url.toString()
59- val url = Url (requestUrl)
60- var path = url.encodedPath
61- if (path.isEmpty()) {
62- url.protocolWithAuthority
63- } else if (path.endsWith(" /" )) {
64- url.protocolWithAuthority + path.removeSuffix(" /" )
65- } else {
66- // the last item is not a directory, so will not be taken into account
67- path = path.substring(0 , path.lastIndexOf(" /" ))
68- url.protocolWithAuthority + path
56+ private val baseUrl: String by lazy {
57+ session.call.request.url.let { url ->
58+ val path = url.encodedPath
59+ when {
60+ path.isEmpty() -> url.protocolWithAuthority
61+ path.endsWith(" /" ) -> url.protocolWithAuthority + path.removeSuffix(" /" )
62+ else -> url.protocolWithAuthority + path.take(path.lastIndexOf(" /" ))
63+ }
6964 }
7065 }
7166
7267 override suspend fun start () {
73- if (! initialized.compareAndSet(expectedValue = false , newValue = true )) {
74- error(
75- " SSEClientTransport already started! " +
76- " If using Client class, note that connect() calls start() automatically." ,
77- )
68+ check(initialized.compareAndSet(expectedValue = false , newValue = true )) {
69+ " SSEClientTransport already started! If using Client class, note that connect() calls start() automatically."
7870 }
7971
80- session = urlString?.let {
81- client.sseSession(
82- urlString = it,
72+ try {
73+ session = urlString?.let {
74+ client.sseSession(
75+ urlString = it,
76+ reconnectionTime = reconnectionTime,
77+ block = requestBuilder,
78+ )
79+ } ? : client.sseSession(
8380 reconnectionTime = reconnectionTime,
8481 block = requestBuilder,
8582 )
86- } ? : client.sseSession(
87- reconnectionTime = reconnectionTime,
88- block = requestBuilder,
89- )
90-
91- job = scope.launch(CoroutineName (" SseMcpClientTransport.collect#${hashCode()} " )) {
92- session.incoming.collect { event ->
93- when (event.event) {
94- " error" -> {
95- val e = IllegalStateException (" SSE error: ${event.data} " )
96- _onError (e)
97- throw e
98- }
99-
100- " open" -> {
101- // The connection is open, but we need to wait for the endpoint to be received.
102- }
103-
104- " endpoint" -> {
105- try {
106- val eventData = event.data ? : " "
107-
108- // check url correctness
109- val maybeEndpoint = Url (" $baseUrl /${if (eventData.startsWith(" /" )) eventData.substring(1 ) else eventData} " )
110- endpoint.complete(maybeEndpoint.toString())
111- } catch (e: Exception ) {
112- _onError (e)
113- close()
114- error(e)
115- }
116- }
83+ scope = CoroutineScope (session.coroutineContext + SupervisorJob ())
11784
118- else -> {
119- try {
120- val message = McpJson .decodeFromString<JSONRPCMessage >(event.data ? : " " )
121- _onMessage (message)
122- } catch (e: Exception ) {
123- _onError (e)
124- }
125- }
126- }
85+ job = scope.launch(CoroutineName (" SseMcpClientTransport.connect#${hashCode()} " )) {
86+ collectMessages()
12787 }
128- }
12988
130- endpoint.await()
89+ endpoint.await()
90+ } catch (e: Exception ) {
91+ closeResources()
92+ initialized.store(false )
93+ throw e
94+ }
13195 }
13296
13397 @OptIn(ExperimentalCoroutinesApi ::class )
13498 override suspend fun send (message : JSONRPCMessage ) {
135- if ( ! endpoint.isCompleted) {
136- error( " Not connected " )
137- }
99+ check(initialized.load()) { " SseClientTransport is not initialized! " }
100+ check(job?.isActive == true ) { " SseClientTransport is closed! " }
101+ check(endpoint.isCompleted) { " Not connected! " }
138102
139103 try {
140104 val response = client.post(endpoint.getCompleted()) {
@@ -147,19 +111,80 @@ public class SseClientTransport(
147111 val text = response.bodyAsText()
148112 error(" Error POSTing to endpoint (HTTP ${response.status} ): $text " )
149113 }
150- } catch (e: Exception ) {
114+ } catch (e: Throwable ) {
151115 _onError (e)
152116 throw e
153117 }
154118 }
155119
156120 override suspend fun close () {
157- if (! initialized.load()) {
158- error(" SSEClientTransport is not initialized!" )
121+ check(initialized.load()) { " SseClientTransport is not initialized!" }
122+ closeResources()
123+ }
124+
125+ private suspend fun CoroutineScope.collectMessages () {
126+ try {
127+ session.incoming.collect { event ->
128+ ensureActive()
129+
130+ when (event.event) {
131+ " error" -> {
132+ val error = IllegalStateException (" SSE error: ${event.data} " )
133+ _onError (error)
134+ throw error
135+ }
136+
137+ " open" -> {
138+ // The connection is open, but we need to wait for the endpoint to be received.
139+ }
140+
141+ " endpoint" -> handleEndpoint(event.data.orEmpty())
142+ else -> handleMessage(event.data.orEmpty())
143+ }
144+ }
145+ } catch (e: CancellationException ) {
146+ throw e
147+ } catch (e: Throwable ) {
148+ _onError (e)
149+ throw e
150+ } finally {
151+ closeResources()
159152 }
153+ }
154+
155+ private fun handleEndpoint (eventData : String ) {
156+ try {
157+ val path = if (eventData.startsWith(" /" )) eventData.substring(1 ) else eventData
158+ val endpointUrl = Url (" $baseUrl /$path " )
159+ endpoint.complete(endpointUrl.toString())
160+ } catch (e: Throwable ) {
161+ _onError (e)
162+ endpoint.completeExceptionally(e)
163+ throw e
164+ }
165+ }
166+
167+ private suspend fun handleMessage (data : String ) {
168+ try {
169+ val message = McpJson .decodeFromString<JSONRPCMessage >(data)
170+ _onMessage (message)
171+ } catch (e: SerializationException ) {
172+ _onError (e)
173+ }
174+ }
175+
176+ private suspend fun closeResources () {
177+ if (! initialized.compareAndSet(expectedValue = true , newValue = false )) return
160178
161- session.cancel()
162- _onClose ()
163179 job?.cancelAndJoin()
180+ try {
181+ if (::session.isInitialized) session.cancel()
182+ if (::scope.isInitialized) scope.cancel()
183+ endpoint.cancel()
184+ } catch (e: Throwable ) {
185+ _onError (e)
186+ }
187+
188+ _onClose ()
164189 }
165190}
0 commit comments