Skip to content

Commit 498341a

Browse files
authored
feat: Add Host header override support for manual DNS resolution (#674)
* feat: Add Host header override support for manual DNS resolution * feat: Add Host header override support for manual DNS resolution * feat: Add Host header override support for manual DNS resolution
1 parent 1e5bacc commit 498341a

File tree

5 files changed

+336
-0
lines changed

5 files changed

+336
-0
lines changed

client/sse.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
2020
return transport.WithHTTPClient(httpClient)
2121
}
2222

23+
// WithHTTPHost sets a custom Host header for the SSE client, enabling manual DNS resolution.
24+
// This allows connecting to an IP address while sending a specific Host header to the server.
25+
// For example, connecting to "http://192.168.1.100:8080/sse" but sending Host: "api.example.com"
26+
func WithHTTPHost(host string) transport.ClientOption {
27+
return transport.WithHTTPHost(host)
28+
}
29+
2330
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
2431
// Returns an error if the URL is invalid.
2532
func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) {

client/transport/sse.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type SSE struct {
3434
endpointChan chan struct{}
3535
headers map[string]string
3636
headerFunc HTTPHeaderFunc
37+
host string
3738
logger util.Logger
3839

3940
started atomic.Bool
@@ -80,6 +81,15 @@ func WithOAuth(config OAuthConfig) ClientOption {
8081
}
8182
}
8283

84+
// WithHTTPHost sets a custom Host header for the SSE client, enabling manual DNS resolution.
85+
// This allows connecting to an IP address while sending a specific Host header to the server.
86+
// For example, connecting to "http://192.168.1.100:8080/sse" but sending Host: "api.example.com"
87+
func WithHTTPHost(host string) ClientOption {
88+
return func(sc *SSE) {
89+
sc.host = host
90+
}
91+
}
92+
8393
// NewSSE creates a new SSE-based MCP client with the given base URL.
8494
// Returns an error if the URL is invalid.
8595
func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
@@ -126,6 +136,11 @@ func (c *SSE) Start(ctx context.Context) error {
126136
return fmt.Errorf("failed to create request: %w", err)
127137
}
128138

139+
// Set custom Host header if provided
140+
if c.host != "" {
141+
req.Host = c.host
142+
}
143+
129144
req.Header.Set("Accept", "text/event-stream")
130145
req.Header.Set("Cache-Control", "no-cache")
131146
req.Header.Set("Connection", "keep-alive")
@@ -387,6 +402,11 @@ func (c *SSE) SendRequest(
387402
}
388403
}
389404

405+
// Set custom Host header if provided
406+
if c.host != "" {
407+
req.Host = c.host
408+
}
409+
390410
// Add OAuth authorization if configured
391411
if c.oauthHandler != nil {
392412
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
@@ -578,6 +598,11 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
578598
}
579599
}
580600

601+
// Set custom Host header if provided
602+
if c.host != "" {
603+
req.Host = c.host
604+
}
605+
581606
resp, err := c.httpClient.Do(req)
582607
if err != nil {
583608
return fmt.Errorf("failed to send notification: %w", err)

client/transport/sse_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"io"
99
"net/http"
1010
"net/http/httptest"
11+
"net/url"
1112
"strings"
1213
"sync"
1314
"testing"
@@ -1078,6 +1079,146 @@ func TestSSE_SendNotification_Unauthorized_StaticToken(t *testing.T) {
10781079
transport.Close()
10791080
}
10801081

1082+
// TestSSEHostOverride tests the Host header override functionality
1083+
func TestSSEHostOverride(t *testing.T) {
1084+
// Create a test server that captures the Host header
1085+
var capturedHost string
1086+
var mu sync.Mutex
1087+
1088+
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1089+
mu.Lock()
1090+
capturedHost = r.Host
1091+
mu.Unlock()
1092+
1093+
w.Header().Set("Content-Type", "text/event-stream")
1094+
flusher, ok := w.(http.Flusher)
1095+
if !ok {
1096+
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
1097+
return
1098+
}
1099+
1100+
// Send initial endpoint event
1101+
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message")
1102+
flusher.Flush()
1103+
1104+
// Keep connection open
1105+
<-r.Context().Done()
1106+
})
1107+
1108+
messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1109+
if r.Method != http.MethodPost {
1110+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
1111+
return
1112+
}
1113+
1114+
response := JSONRPCResponse{
1115+
JSONRPC: "2.0",
1116+
ID: mcp.NewRequestId(1),
1117+
Result: []byte("test"),
1118+
}
1119+
1120+
w.Header().Set("Content-Type", "application/json")
1121+
if err := json.NewEncoder(w).Encode(response); err != nil {
1122+
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
1123+
return
1124+
}
1125+
})
1126+
1127+
mux := http.NewServeMux()
1128+
mux.Handle("/", sseHandler)
1129+
mux.Handle("/message", messageHandler)
1130+
1131+
testServer := httptest.NewServer(mux)
1132+
defer testServer.Close()
1133+
1134+
// Parse test server URL to get the actual host
1135+
serverURL, _ := url.Parse(testServer.URL)
1136+
actualHost := serverURL.Host
1137+
1138+
t.Run("Default Host (no override)", func(t *testing.T) {
1139+
capturedHost = ""
1140+
trans, err := NewSSE(testServer.URL)
1141+
require.NoError(t, err)
1142+
1143+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1144+
defer cancel()
1145+
1146+
err = trans.Start(ctx)
1147+
require.NoError(t, err)
1148+
defer trans.Close()
1149+
1150+
// Host should match the actual server host
1151+
mu.Lock()
1152+
require.Equal(t, actualHost, capturedHost)
1153+
mu.Unlock()
1154+
})
1155+
1156+
t.Run("Custom Host override", func(t *testing.T) {
1157+
capturedHost = ""
1158+
customHost := "api.example.com"
1159+
1160+
trans, err := NewSSE(testServer.URL, WithHTTPHost(customHost))
1161+
require.NoError(t, err)
1162+
1163+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1164+
defer cancel()
1165+
1166+
err = trans.Start(ctx)
1167+
require.NoError(t, err)
1168+
defer trans.Close()
1169+
1170+
// Host should be the custom host, not the actual server host
1171+
mu.Lock()
1172+
require.Equal(t, customHost, capturedHost)
1173+
require.NotEqual(t, actualHost, capturedHost)
1174+
mu.Unlock()
1175+
})
1176+
1177+
t.Run("Custom Host with port", func(t *testing.T) {
1178+
capturedHost = ""
1179+
customHost := "backend.internal.com:8443"
1180+
1181+
trans, err := NewSSE(testServer.URL, WithHTTPHost(customHost))
1182+
require.NoError(t, err)
1183+
1184+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1185+
defer cancel()
1186+
1187+
err = trans.Start(ctx)
1188+
require.NoError(t, err)
1189+
defer trans.Close()
1190+
1191+
// Host should be the custom host with port
1192+
mu.Lock()
1193+
require.Equal(t, customHost, capturedHost)
1194+
mu.Unlock()
1195+
})
1196+
1197+
// Test WithHTTPHost function directly (unit test)
1198+
t.Run("WithHTTPHost function", func(t *testing.T) {
1199+
sse := &SSE{}
1200+
customHost := "test.example.com"
1201+
1202+
option := WithHTTPHost(customHost)
1203+
option(sse)
1204+
1205+
require.Equal(t, customHost, sse.host)
1206+
1207+
// Test overwrite
1208+
newHost := "new.example.com"
1209+
option = WithHTTPHost(newHost)
1210+
option(sse)
1211+
1212+
require.Equal(t, newHost, sse.host)
1213+
1214+
// Test empty string
1215+
option = WithHTTPHost("")
1216+
option(sse)
1217+
1218+
require.Equal(t, "", sse.host)
1219+
})
1220+
}
1221+
10811222
func TestSSE_SendRequest_Timeout(t *testing.T) {
10821223
t.Run("TimeoutWhenServerNeverResponds", func(t *testing.T) {
10831224
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

client/transport/streamable_http.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ func WithSession(sessionID string) StreamableHTTPCOption {
8787
}
8888
}
8989

90+
// WithStreamableHTTPHost sets a custom Host header for the StreamableHTTP client, enabling manual DNS resolution.
91+
// This allows connecting to an IP address while sending a specific Host header to the server.
92+
// For example, connecting to "http://192.168.1.100:8080/mcp" but sending Host: "api.example.com"
93+
func WithStreamableHTTPHost(host string) StreamableHTTPCOption {
94+
return func(sc *StreamableHTTP) {
95+
sc.host = host
96+
}
97+
}
98+
9099
// StreamableHTTP implements Streamable HTTP transport.
91100
//
92101
// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
@@ -103,6 +112,7 @@ type StreamableHTTP struct {
103112
httpClient *http.Client
104113
headers map[string]string
105114
headerFunc HTTPHeaderFunc
115+
host string
106116
logger util.Logger
107117
getListeningEnabled bool
108118

@@ -217,6 +227,11 @@ func (c *StreamableHTTP) Close() error {
217227
req.Header.Set(HeaderKeyProtocolVersion, version)
218228
}
219229
}
230+
231+
// Set custom Host header if provided
232+
if c.host != "" {
233+
req.Host = c.host
234+
}
220235
res, err := c.httpClient.Do(req)
221236
if err != nil {
222237
c.logger.Errorf("failed to send close request: %v", err)
@@ -379,6 +394,11 @@ func (c *StreamableHTTP) sendHTTP(
379394
req.Header.Set(k, v)
380395
}
381396

397+
// Set custom Host header if provided
398+
if c.host != "" {
399+
req.Host = c.host
400+
}
401+
382402
// Add OAuth authorization if configured
383403
if c.oauthHandler != nil {
384404
authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)

0 commit comments

Comments
 (0)