Skip to content

Commit 43d388a

Browse files
authored
Merge commit from fork
Signed-off-by: Kent Rancourt <kent.rancourt@gmail.com>
1 parent a81c37e commit 43d388a

File tree

3 files changed

+177
-4
lines changed

3 files changed

+177
-4
lines changed

internal/promotion/runner/builtin/http_requester.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ import (
1717
kargoapi "github.com/akuity/kargo/api/v1alpha1"
1818
"github.com/akuity/kargo/internal/io"
1919
"github.com/akuity/kargo/internal/logging"
20+
kargonet "github.com/akuity/kargo/pkg/net"
2021
"github.com/akuity/kargo/pkg/promotion"
2122
"github.com/akuity/kargo/pkg/x/promotion/runner/builtin"
2223
)
2324

2425
const (
25-
contentTypeHeader = "Content-Type"
26-
contentTypeJSON = "application/json"
27-
maxResponseBytes = 2 << 20
26+
contentTypeHeader = "Content-Type"
27+
contentTypeJSON = "application/json"
28+
maxResponseBytes = 2 << 20
2829
requestTimeoutDefault = 10 * time.Second
2930
)
3031

@@ -83,6 +84,8 @@ func (h *httpRequester) run(
8384
return promotion.StepResult{Status: kargoapi.PromotionStepStatusErrored},
8485
fmt.Errorf("error creating HTTP client: %w", err)
8586
}
87+
// #nosec G704 -- The client is using a custom dialer that mitigates the worst
88+
// practical risks of SSRF by refusing to dial link-local addresses.
8689
resp, err := client.Do(req)
8790
if err != nil {
8891
return promotion.StepResult{Status: kargoapi.PromotionStepStatusErrored},
@@ -149,7 +152,7 @@ func (h *httpRequester) buildRequest(cfg builtin.HTTPConfig) (*http.Request, err
149152
}
150153

151154
func (h *httpRequester) getClient(cfg builtin.HTTPConfig) (*http.Client, error) {
152-
httpTransport := cleanhttp.DefaultTransport()
155+
httpTransport := kargonet.SafeTransport(cleanhttp.DefaultTransport())
153156
if cfg.InsecureSkipTLSVerify {
154157
httpTransport.TLSClientConfig = &tls.Config{
155158
InsecureSkipVerify: true, // nolint: gosec

pkg/net/safe_dialer.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package net
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"net/http"
8+
"time"
9+
)
10+
11+
var (
12+
linkLocalV4 = net.IPNet{
13+
IP: net.IP{169, 254, 0, 0},
14+
Mask: net.CIDRMask(16, 32),
15+
}
16+
linkLocalV6 = net.IPNet{
17+
IP: net.ParseIP("fe80::"),
18+
Mask: net.CIDRMask(10, 128),
19+
}
20+
)
21+
22+
// isLinkLocal returns true if the given IP is in the IPv4 link-local range
23+
// (169.254.0.0/16) or the IPv6 link-local range (fe80::/10).
24+
func isLinkLocal(ip net.IP) bool {
25+
return linkLocalV4.Contains(ip) || linkLocalV6.Contains(ip)
26+
}
27+
28+
// SafeDialContext returns a DialContext function that blocks connections to
29+
// link-local IP addresses (169.254.0.0/16 and fe80::/10). This prevents SSRF
30+
// attacks targeting cloud instance metadata endpoints (e.g. 169.254.169.254).
31+
//
32+
// The returned function resolves the hostname before connecting and rejects the
33+
// connection if all resolved addresses are link-local.
34+
func SafeDialContext(dialer *net.Dialer) func(
35+
ctx context.Context,
36+
network string,
37+
addr string,
38+
) (net.Conn, error) {
39+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
40+
host, port, err := net.SplitHostPort(addr)
41+
if err != nil {
42+
return nil, fmt.Errorf("failed to parse address %q: %w", addr, err)
43+
}
44+
45+
// Resolve the hostname to IP addresses.
46+
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
47+
if err != nil {
48+
return nil, fmt.Errorf("failed to resolve host %q: %w", host, err)
49+
}
50+
51+
// Filter out link-local addresses.
52+
var safe []net.IPAddr
53+
for _, ip := range ips {
54+
if !isLinkLocal(ip.IP) {
55+
safe = append(safe, ip)
56+
}
57+
}
58+
59+
if len(safe) == 0 {
60+
return nil, fmt.Errorf(
61+
"connections to link-local addresses are not permitted "+
62+
"(host %q resolved to link-local IPs only)",
63+
host,
64+
)
65+
}
66+
67+
// Dial using the first safe address.
68+
safeAddr := net.JoinHostPort(safe[0].IP.String(), port)
69+
return dialer.DialContext(ctx, network, safeAddr)
70+
}
71+
}
72+
73+
// SafeTransport wraps the given transport's DialContext to block connections to
74+
// link-local IP addresses.
75+
func SafeTransport(t *http.Transport) *http.Transport {
76+
dialer := &net.Dialer{
77+
Timeout: 30 * time.Second,
78+
KeepAlive: 30 * time.Second,
79+
}
80+
t.DialContext = SafeDialContext(dialer)
81+
return t
82+
}

pkg/net/safe_dialer_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package net
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func Test_isLinkLocal(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
ip string
14+
expected bool
15+
}{
16+
{
17+
name: "IPv4 link-local lower bound",
18+
ip: "169.254.0.0",
19+
expected: true,
20+
},
21+
{
22+
name: "IPv4 link-local metadata endpoint",
23+
ip: "169.254.169.254",
24+
expected: true,
25+
},
26+
{
27+
name: "IPv4 link-local upper bound",
28+
ip: "169.254.255.255",
29+
expected: true,
30+
},
31+
{
32+
name: "IPv4 just below link-local range",
33+
ip: "169.253.255.255",
34+
expected: false,
35+
},
36+
{
37+
name: "IPv4 just above link-local range",
38+
ip: "169.255.0.0",
39+
expected: false,
40+
},
41+
{
42+
name: "IPv4 private 10.x",
43+
ip: "10.0.0.1",
44+
expected: false,
45+
},
46+
{
47+
name: "IPv4 public",
48+
ip: "8.8.8.8",
49+
expected: false,
50+
},
51+
{
52+
name: "IPv4 loopback",
53+
ip: "127.0.0.1",
54+
expected: false,
55+
},
56+
{
57+
name: "IPv6 link-local",
58+
ip: "fe80::1",
59+
expected: true,
60+
},
61+
{
62+
name: "IPv6 link-local upper bound",
63+
ip: "febf::ffff",
64+
expected: true,
65+
},
66+
{
67+
name: "IPv6 just outside link-local",
68+
ip: "fec0::1",
69+
expected: false,
70+
},
71+
{
72+
name: "IPv6 loopback",
73+
ip: "::1",
74+
expected: false,
75+
},
76+
{
77+
name: "IPv6 public",
78+
ip: "2001:db8::1",
79+
expected: false,
80+
},
81+
}
82+
for _, tt := range tests {
83+
t.Run(tt.name, func(t *testing.T) {
84+
ip := net.ParseIP(tt.ip)
85+
assert.Equal(t, tt.expected, isLinkLocal(ip))
86+
})
87+
}
88+
}

0 commit comments

Comments
 (0)