Skip to content

Commit 2b3f76e

Browse files
committed
feat(web_fetch): enhance web fetch tool with DNS pinning and validation improvements
- Introduced a new validatedParams struct to hold validated input along with the DNS-pinned host/IP for SSRF protection.
- Replaced the validateParams function with validateAndResolve, which validates the input and resolves the host to a single public IP, ensuring safe outbound fetches.
- Modified the executeFetch and fetchHTMLContent methods to use the new validatedParams for improved security and clarity.
- Enhanced logging to reflect the display URL and provide better context during fetch operations.

This update strengthens the web fetch tool's security against SSRF attacks and improves the overall robustness of URL handling.
1 parent a5d1233 commit 2b3f76e

File tree

2 files changed

+112
-43
lines changed

2 files changed

+112
-43
lines changed

internal/agent/tools/web_fetch.go

Lines changed: 101 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,9 @@ import (
55
"encoding/json"
66
"fmt"
77
"io"
8+
"net"
89
"net/http"
10+
"net/url"
911
"regexp"
1012
"strings"
1113
"sync"
@@ -58,6 +60,16 @@ type webFetchParams struct {
5860
Prompt string
5961
}
6062

63+
// validatedParams holds validated input plus DNS-pinned host/IP for SSRF protection.
64+
// PinnedIP is the single IP we resolved at validation time; chromedp and HTTP both use it.
65+
type validatedParams struct {
66+
URL string
67+
Prompt string
68+
Host string
69+
Port string
70+
PinnedIP net.IP
71+
}
72+
6173
// webFetchItemResult is the result for a web fetch item
6274
type webFetchItemResult struct {
6375
output string
@@ -124,7 +136,10 @@ func (t *WebFetchTool) Execute(ctx context.Context, args json.RawMessage) (*type
124136
go func(index int, p webFetchParams) {
125137
defer wg.Done()
126138

127-
if err := t.validateParams(p); err != nil {
139+
// Normalize URL before validation so we pin the host we actually fetch (e.g. raw.githubusercontent.com)
140+
finalURL := t.normalizeGitHubURL(p.URL)
141+
vp, err := t.validateAndResolve(webFetchParams{URL: finalURL, Prompt: p.Prompt})
142+
if err != nil {
128143
results[index] = &webFetchItemResult{
129144
err: err,
130145
data: map[string]interface{}{
@@ -137,7 +152,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args json.RawMessage) (*type
137152
return
138153
}
139154

140-
output, data, err := t.executeFetch(ctx, p)
155+
output, data, err := t.executeFetch(ctx, vp, p.URL)
141156
results[index] = &webFetchItemResult{
142157
output: output,
143158
data: data,
@@ -234,60 +249,97 @@ func (t *WebFetchTool) parseParams(item interface{}) webFetchParams {
234249
return params
235250
}
236251

237-
// validateParams validates the parameters for a web fetch item
238-
func (t *WebFetchTool) validateParams(p webFetchParams) error {
252+
// validateAndResolve validates parameters and resolves the host to a single public IP (DNS pinning).
253+
// The returned PinnedIP is used for both chromedp (host-resolver-rules) and HTTP to prevent DNS rebinding.
254+
func (t *WebFetchTool) validateAndResolve(p webFetchParams) (*validatedParams, error) {
239255
if p.URL == "" {
240-
return fmt.Errorf("url is required")
256+
return nil, fmt.Errorf("url is required")
241257
}
242258
if p.Prompt == "" {
243-
return fmt.Errorf("prompt is required")
259+
return nil, fmt.Errorf("prompt is required")
244260
}
245261
if !strings.HasPrefix(p.URL, "http://") && !strings.HasPrefix(p.URL, "https://") {
246-
return fmt.Errorf("invalid URL format")
262+
return nil, fmt.Errorf("invalid URL format")
247263
}
248264

249-
// SSRF protection: validate URL is safe to fetch
265+
// SSRF protection: validate URL is safe (scheme, hostname, and that resolved IPs are not restricted)
250266
if safe, reason := utils.IsSSRFSafeURL(p.URL); !safe {
251-
return fmt.Errorf("URL rejected for security reasons: %s", reason)
267+
return nil, fmt.Errorf("URL rejected for security reasons: %s", reason)
268+
}
269+
270+
u, err := url.Parse(p.URL)
271+
if err != nil {
272+
return nil, fmt.Errorf("invalid URL: %w", err)
273+
}
274+
hostname := u.Hostname()
275+
port := u.Port()
276+
if port == "" {
277+
if u.Scheme == "https" {
278+
port = "443"
279+
} else {
280+
port = "80"
281+
}
252282
}
253283

254-
return nil
284+
// Resolve and pin to the first public IP (same resolver as IsSSRFSafeURL; we pin so chromedp cannot re-resolve)
285+
ips, err := net.DefaultResolver.LookupIP(context.Background(), "ip", hostname)
286+
if err != nil || len(ips) == 0 {
287+
return nil, fmt.Errorf("DNS lookup failed for %s: %w", hostname, err)
288+
}
289+
var pinnedIP net.IP
290+
for _, ip := range ips {
291+
if utils.IsPublicIP(ip) {
292+
pinnedIP = ip
293+
break
294+
}
295+
}
296+
if pinnedIP == nil {
297+
return nil, fmt.Errorf("no public IP available for host %s", hostname)
298+
}
299+
300+
return &validatedParams{
301+
URL: p.URL,
302+
Prompt: p.Prompt,
303+
Host: hostname,
304+
Port: port,
305+
PinnedIP: pinnedIP,
306+
}, nil
255307
}
256308

257-
// executeFetch executes a web fetch item
309+
// executeFetch executes a web fetch item. displayURL is the URL shown to the user (e.g. original); vp.URL is the normalized URL we fetch.
258310
func (t *WebFetchTool) executeFetch(
259311
ctx context.Context,
260-
params webFetchParams,
312+
vp *validatedParams,
313+
displayURL string,
261314
) (string, map[string]interface{}, error) {
262-
logger.Infof(ctx, "[Tool][WebFetch] Fetching URL: %s", params.URL)
263-
264-
finalURL := t.normalizeGitHubURL(params.URL)
315+
logger.Infof(ctx, "[Tool][WebFetch] Fetching URL: %s", displayURL)
265316

266-
htmlContent, method, err := t.fetchHTMLContent(ctx, finalURL)
317+
htmlContent, method, err := t.fetchHTMLContent(ctx, vp)
267318
if err != nil {
268-
logger.Errorf(ctx, "[Tool][WebFetch] 获取页面失败 url=%s err=%v", finalURL, err)
269-
return fmt.Sprintf("URL: %s\n错误: %v\n", params.URL, err),
319+
logger.Errorf(ctx, "[Tool][WebFetch] 获取页面失败 url=%s err=%v", vp.URL, err)
320+
return fmt.Sprintf("URL: %s\n错误: %v\n", displayURL, err),
270321
map[string]interface{}{
271-
"url": params.URL,
272-
"prompt": params.Prompt,
322+
"url": displayURL,
323+
"prompt": vp.Prompt,
273324
"error": err.Error(),
274325
}, err
275326
}
276327

277328
textContent := t.convertHTMLToText(htmlContent)
278329

279330
resultData := map[string]interface{}{
280-
"url": params.URL,
281-
"prompt": params.Prompt,
331+
"url": displayURL,
332+
"prompt": vp.Prompt,
282333
"raw_content": textContent,
283334
"content_length": len(textContent),
284335
"method": method,
285336
}
337+
params := webFetchParams{URL: displayURL, Prompt: vp.Prompt}
286338
var summary string
287339
var summaryErr error
288340
summary, summaryErr = t.processWithLLM(ctx, params, textContent)
289341
if summaryErr != nil {
290-
logger.Warnf(ctx, "[Tool][WebFetch] LLM 处理失败 url=%s err=%v", params.URL, summaryErr)
342+
logger.Warnf(ctx, "[Tool][WebFetch] LLM 处理失败 url=%s err=%v", displayURL, summaryErr)
291343
} else if summary != "" {
292344
resultData["summary"] = summary
293345
}
@@ -360,18 +412,18 @@ func (t *WebFetchTool) buildOutputText(params webFetchParams, content string, su
360412
return builder.String()
361413
}
362414

363-
// fetchHTMLContent fetches the HTML content for a web fetch item
364-
func (t *WebFetchTool) fetchHTMLContent(ctx context.Context, targetURL string) (string, string, error) {
365-
html, err := t.fetchWithChromedp(ctx, targetURL)
415+
// fetchHTMLContent fetches the HTML content for a web fetch item using pinned IP (DNS pinning).
416+
func (t *WebFetchTool) fetchHTMLContent(ctx context.Context, vp *validatedParams) (string, string, error) {
417+
html, err := t.fetchWithChromedp(ctx, vp)
366418
if err == nil && strings.TrimSpace(html) != "" {
367419
return html, "chromedp", nil
368420
}
369421

370422
if err != nil {
371-
logger.Debugf(ctx, "[Tool][WebFetch] Chromedp 抓取失败 url=%s err=%v,尝试直接请求", targetURL, err)
423+
logger.Debugf(ctx, "[Tool][WebFetch] Chromedp 抓取失败 url=%s err=%v,尝试直接请求", vp.URL, err)
372424
}
373425

374-
html, httpErr := t.fetchWithHTTP(ctx, targetURL)
426+
html, httpErr := t.fetchWithHTTP(ctx, vp)
375427
if httpErr != nil {
376428
if err != nil {
377429
return "", "", fmt.Errorf("chromedp error: %v; http error: %w", err, httpErr)
@@ -382,12 +434,15 @@ func (t *WebFetchTool) fetchHTMLContent(ctx context.Context, targetURL string) (
382434
return html, "http", nil
383435
}
384436

385-
// fetchWithChromedp fetches the HTML content with Chromedp
386-
func (t *WebFetchTool) fetchWithChromedp(ctx context.Context, targetURL string) (string, error) {
387-
logger.Debugf(ctx, "[Tool][WebFetch] Chromedp 抓取开始 url=%s", targetURL)
437+
// fetchWithChromedp fetches the HTML content with Chromedp. Uses host-resolver-rules to pin host to vp.PinnedIP (DNS rebinding protection).
438+
func (t *WebFetchTool) fetchWithChromedp(ctx context.Context, vp *validatedParams) (string, error) {
439+
logger.Debugf(ctx, "[Tool][WebFetch] Chromedp 抓取开始 url=%s", vp.URL)
388440

441+
// DNS pinning: force Chrome to use the IP we resolved at validation time, not a second resolution.
442+
hostRule := fmt.Sprintf("MAP %s %s", vp.Host, vp.PinnedIP.String())
389443
opts := append(
390444
chromedp.DefaultExecAllocatorOptions[:],
445+
chromedp.Flag("host-resolver-rules", hostRule),
391446
chromedp.Flag("headless", true),
392447
chromedp.Flag("disable-setuid-sandbox", true),
393448
chromedp.Flag("disable-dev-shm-usage", true),
@@ -410,21 +465,21 @@ func (t *WebFetchTool) fetchWithChromedp(ctx context.Context, targetURL string)
410465

411466
var html string
412467
err := chromedp.Run(ctx,
413-
chromedp.Navigate(targetURL),
468+
chromedp.Navigate(vp.URL),
414469
chromedp.WaitReady("body", chromedp.ByQuery),
415470
chromedp.OuterHTML("html", &html),
416471
)
417472
if err != nil {
418473
return "", fmt.Errorf("chromedp run failed: %w", err)
419474
}
420475

421-
logger.Debugf(ctx, "[Tool][WebFetch] Chromedp 抓取成功 url=%s", targetURL)
476+
logger.Debugf(ctx, "[Tool][WebFetch] Chromedp 抓取成功 url=%s", vp.URL)
422477
return html, nil
423478
}
424479

425-
// fetchWithHTTP fetches the HTML content with HTTP
426-
func (t *WebFetchTool) fetchWithHTTP(ctx context.Context, targetURL string) (string, error) {
427-
resp, err := t.fetchWithTimeout(ctx, targetURL)
480+
// fetchWithHTTP fetches the HTML content with HTTP using pinned IP (same as chromedp path).
481+
func (t *WebFetchTool) fetchWithHTTP(ctx context.Context, vp *validatedParams) (string, error) {
482+
resp, err := t.fetchWithTimeout(ctx, vp)
428483
if err != nil {
429484
return "", err
430485
}
@@ -443,12 +498,19 @@ func (t *WebFetchTool) fetchWithHTTP(ctx context.Context, targetURL string) (str
443498
return string(htmlBytes), nil
444499
}
445500

446-
// fetchWithTimeout fetches the HTML content with a timeout
447-
func (t *WebFetchTool) fetchWithTimeout(ctx context.Context, targetURL string) (*http.Response, error) {
448-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
501+
// fetchWithTimeout fetches the HTML content with a timeout. Uses pinned IP and original Host header (DNS pinning).
502+
func (t *WebFetchTool) fetchWithTimeout(ctx context.Context, vp *validatedParams) (*http.Response, error) {
503+
// Connect to pinned IP so we do not re-resolve; set Host so the server gets the right virtual host.
504+
hostPort := net.JoinHostPort(vp.PinnedIP.String(), vp.Port)
505+
rawURL := vp.URL
506+
u, _ := url.Parse(rawURL)
507+
u.Host = hostPort
508+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
449509
if err != nil {
450510
return nil, fmt.Errorf("failed to create request: %w", err)
451511
}
512+
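// Preserve the original host in the Host header (required for virtual hosting). NOTE(review): req.Host does not affect TLS SNI or certificate verification; net/http derives the TLS server name from the request URL's host, which is the pinned IP here. Confirm HTTPS fetches succeed, or set the transport's TLSClientConfig.ServerName to vp.Host.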
513+
req.Host = net.JoinHostPort(vp.Host, vp.Port)
452514

453515
req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; WebFetchTool/1.0)")
454516
req.Header.Set(

internal/utils/security.go

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,13 @@ func isRestrictedIP(ip net.IP) (bool, string) {
244244
return false, ""
245245
}
246246

247+
// IsPublicIP returns true if the IP is safe for outbound fetch (not private, loopback, link-local, etc.).
248+
// Used for DNS pinning: after resolving a hostname we pick the first public IP and pin all requests to it.
249+
func IsPublicIP(ip net.IP) bool {
250+
restricted, _ := isRestrictedIP(ip)
251+
return !restricted
252+
}
253+
247254
// isZeros checks if a byte slice is all zeros
248255
func isZeros(b []byte) bool {
249256
for _, v := range b {
@@ -652,10 +659,10 @@ func ValidateStdioConfig(command string, args []string, envVars map[string]strin
652659

653660
// SSRFSafeHTTPClientConfig contains configuration for the SSRF-safe HTTP client
654661
type SSRFSafeHTTPClientConfig struct {
655-
Timeout time.Duration
656-
MaxRedirects int
657-
DisableKeepAlives bool
658-
DisableCompression bool
662+
Timeout time.Duration
663+
MaxRedirects int
664+
DisableKeepAlives bool
665+
DisableCompression bool
659666
}
660667

661668
// DefaultSSRFSafeHTTPClientConfig returns the default configuration

0 commit comments

Comments
 (0)