Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 174 additions & 59 deletions lib/ruby_llm/mcp/native/transports/sse.rb
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,32 @@ def initialize(url:, coordinator:, request_timeout:, options: {})
@pending_requests = {}
@pending_mutex = Mutex.new
@connection_mutex = Mutex.new
@state_mutex = Mutex.new
@running = false
@sse_thread = nil
@sse_response = nil

RubyLLM::MCP.logger.info "Initializing SSE transport to #{@event_url} with client ID #{@client_id}"
end

def request(body, add_id: true, wait_for_response: true)
def request(body, add_id: true, wait_for_response: true) # rubocop:disable Metrics/MethodLength
request_id = nil

if add_id
@id_mutex.synchronize { @id_counter += 1 }
request_id = @id_counter
body["id"] = request_id
elsif body.is_a?(Hash)
request_id = body["id"] || body[:id]
end

if wait_for_response && request_id.nil?
raise ArgumentError, "Request ID must be provided when wait_for_response is true and add_id is false"
end

response_queue = Queue.new
response_queue = nil
if wait_for_response
response_queue = Queue.new
@pending_mutex.synchronize do
@pending_requests[request_id.to_s] = response_queue
end
Expand All @@ -65,41 +76,83 @@ def request(body, add_id: true, wait_for_response: true)
begin
send_request(body, request_id)
rescue Errors::TransportError, Errors::TimeoutError => e
@pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) }
if wait_for_response && request_id
@pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) }
end
RubyLLM::MCP.logger.error "Request error (ID: #{request_id}): #{e.message}"
raise e
end

return unless wait_for_response

result = nil
begin
with_timeout(@request_timeout / 1000, request_id: request_id) do
result = with_timeout(@request_timeout / 1000, request_id: request_id) do
response_queue.pop
end
rescue Errors::TimeoutError => e
@pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) }
if request_id
@pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) }
end
RubyLLM::MCP.logger.error "SSE request timeout (ID: #{request_id}) \
after #{@request_timeout / 1000} seconds."
raise e
end

raise result if result.is_a?(Errors::TransportError)

result
end

def alive?
@running
running?
end

def running?
@state_mutex.synchronize { @running }
end

def start
return if @running
@state_mutex.synchronize do
return if @running

@running = true
end

@running = true
start_sse_listener
end

def close
should_close = @state_mutex.synchronize do
return unless @running

@running = false
true
end

return unless should_close

RubyLLM::MCP.logger.info "Closing SSE transport connection"
@running = false
@sse_thread&.join(1) # Give the thread a second to clean up

# Close the SSE response stream if it exists
begin
@sse_response&.body&.close
rescue StandardError => e
RubyLLM::MCP.logger.debug "Error closing SSE response: #{e.message}"
end

# Wait for the thread to finish
@sse_thread&.join(1)
@sse_thread = nil

fail_pending_requests!(
Errors::TransportError.new(
message: "SSE transport closed",
code: nil
)
)

@messages_url = nil
end

def set_protocol_version(version)
Expand Down Expand Up @@ -213,7 +266,7 @@ def attempt_authentication_retry(www_authenticate, resource_metadata_url, origin
end

def start_sse_listener
@connection_mutex.synchronize do
@connection_mutex.synchronize do # rubocop:disable Metrics/BlockLength
return if sse_thread_running?

RubyLLM::MCP.logger.info "Starting SSE listener thread"
Expand All @@ -224,44 +277,77 @@ def start_sse_listener
end

@sse_thread = Thread.new do
listen_for_events while @running
listen_for_events
end
@sse_thread.abort_on_exception = true

with_timeout(@request_timeout / 1000) do
endpoint = response_queue.pop
set_message_endpoint(endpoint)
begin
with_timeout(@request_timeout / 1000) do
endpoint = response_queue.pop
set_message_endpoint(endpoint)
end
rescue Errors::TimeoutError => e
@pending_mutex.synchronize do
@pending_requests.delete("endpoint")
end
RubyLLM::MCP.logger.error "Timeout waiting for endpoint event: #{e.message}"
raise e
rescue StandardError => e
@pending_mutex.synchronize do
@pending_requests.delete("endpoint")
end
raise e
end
end
end

def set_message_endpoint(endpoint)
uri = URI.parse(endpoint)
endpoint_url = if endpoint.is_a?(String)
endpoint
elsif endpoint.is_a?(Hash)
# Support richer endpoint metadata (e.g., { "url": "...", "last_event_id": "..." })
endpoint["url"] || endpoint[:url]
else
endpoint.to_s
end

unless endpoint_url && !endpoint_url.empty?
raise Errors::TransportError.new(
message: "Invalid endpoint event: missing URL",
code: nil
)
end

uri = URI.parse(endpoint_url)

@messages_url = if uri.host.nil?
"#{@root_url}#{endpoint}"
"#{@root_url}#{endpoint_url}"
else
endpoint
endpoint_url
end

RubyLLM::MCP.logger.info "SSE message endpoint set to: #{@messages_url}"
rescue URI::InvalidURIError => e
raise Errors::TransportError.new(
message: "Invalid endpoint URL: #{e.message}",
code: nil
)
end

def sse_thread_running?
@sse_thread&.alive?
end

def listen_for_events
stream_events_from_server
stream_events_from_server while running?
rescue StandardError => e
handle_connection_error("SSE connection error", e)
end

def stream_events_from_server
sse_client = create_sse_client
response = sse_client.get(@event_url, stream: true)
validate_sse_response!(response)
process_event_stream(response)
@sse_response = sse_client.get(@event_url, stream: true)
validate_sse_response!(@sse_response)
process_event_stream(@sse_response)
end

def create_sse_client
Expand Down Expand Up @@ -325,11 +411,22 @@ def handle_sse_authentication_challenge(response)
end

def handle_client_error!(error_message, status_code)
@running = false
raise Errors::TransportError.new(
transport_error = Errors::TransportError.new(
message: error_message,
code: status_code
)
close

raise transport_error
end

def fail_pending_requests!(error)
@pending_mutex.synchronize do
@pending_requests.each_value do |queue|
queue.push(error)
end
@pending_requests.clear
end
end

def process_event_stream(response)
Expand All @@ -340,7 +437,7 @@ def process_event_stream(response)
end

def handle_event_line?(event_line, event_buffer, response)
unless @running
unless running?
response.body.close
return false
end
Expand All @@ -365,7 +462,6 @@ def process_buffered_event(event_buffer)
end

def read_error_body(response)
# Try to read the error body from the response
body = ""
begin
response.each do |chunk|
Expand All @@ -378,11 +474,18 @@ def read_error_body(response)
end

def handle_connection_error(message, error)
return unless @running
return unless running?

error_message = "#{message}: #{error.message}"
RubyLLM::MCP.logger.error "#{error_message}. Reconnecting in 1 seconds..."
sleep 1
RubyLLM::MCP.logger.error "#{error_message}. Closing SSE transport."

transport_error = Errors::TransportError.new(
message: error_message,
code: nil
)
close

@coordinator&.handle_error(transport_error)
end

def handle_httpx_error_response!(response, context:)
Expand All @@ -405,44 +508,56 @@ def handle_httpx_error_response!(response, context:)
end

def process_event(raw_event)
# Return if we believe that are getting a partial event
return if raw_event[:data].nil?

if raw_event[:event] == "endpoint"
request_id = "endpoint"
event = raw_event[:data]
return if event.nil?

RubyLLM::MCP.logger.debug "Received endpoint event: #{event}"
@pending_mutex.synchronize do
response_queue = @pending_requests.delete(request_id)
response_queue&.push(event)
end
process_endpoint_event(raw_event)
else
event = begin
JSON.parse(raw_event[:data])
rescue JSON::ParserError => e
# We can sometimes get partial endpoint events, so we will ignore them
unless @endpoint.nil?
RubyLLM::MCP.logger.info "Failed to parse SSE event data: #{raw_event[:data]} - #{e.message}"
end
process_message_event(raw_event)
end
end

def process_endpoint_event(raw_event)
request_id = "endpoint"
event_data = raw_event[:data]
return if event_data.nil?

endpoint = begin
JSON.parse(event_data)
rescue JSON::ParserError
event_data
end

RubyLLM::MCP.logger.debug "Received endpoint event: #{endpoint.inspect}"

nil
@pending_mutex.synchronize do
response_queue = @pending_requests.delete(request_id)
response_queue&.push(endpoint)
end
end

def process_message_event(raw_event)
event = begin
JSON.parse(raw_event[:data])
rescue JSON::ParserError => e
if @messages_url
RubyLLM::MCP.logger.debug "Failed to parse SSE event data: #{raw_event[:data]} - #{e.message}"
end
return if event.nil?
nil
end
return if event.nil?

request_id = event["id"]&.to_s
result = RubyLLM::MCP::Result.new(event)
request_id = event["id"]&.to_s
result = RubyLLM::MCP::Result.new(event)

result = @coordinator.process_result(result)
return if result.nil?
result = @coordinator.process_result(result)
return if result.nil?

@pending_mutex.synchronize do
# You can receieve duplicate events for the same request id, and we will ignore thoses
if result.matching_id?(request_id) && @pending_requests.key?(request_id)
response_queue = @pending_requests.delete(request_id)
response_queue&.push(result)
end
@pending_mutex.synchronize do
# You can receive duplicate events for the same request id, and we will ignore those
if result.matching_id?(request_id) && @pending_requests.key?(request_id)
response_queue = @pending_requests.delete(request_id)
response_queue&.push(result)
end
end
end
Expand Down
Loading