Skip to content

Commit ccd3f4b

Browse files
authored
Merge pull request #782 from docker/AIMI-4
Validate realm URL before token exchange
2 parents 1047b07 + 46145d1 commit ccd3f4b

File tree

2 files changed

+274
-3
lines changed

2 files changed

+274
-3
lines changed

pkg/distribution/oci/remote/transport.go

Lines changed: 166 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ package remote
22

33
import (
44
"context"
5+
"crypto/tls"
56
"encoding/json"
67
"fmt"
78
"io"
9+
"log/slog"
10+
"net"
811
"net/http"
912
"net/url"
1013
"strings"
@@ -38,6 +41,153 @@ type Token struct {
3841
ExpiresIn int `json:"expires_in"`
3942
}
4043

44+
// privateOrLoopbackCIDRs lists IP ranges that must never be contacted as a
// token-exchange realm. Allowing requests to these addresses would let a
// malicious registry pivot Model Runner into an internal-service proxy
// (SSRF), reaching endpoints that are not accessible from the public internet.
//
// The unspecified addresses (0.0.0.0 and ::) are included because the Go
// dialer treats an unspecified destination as "the local system", which makes
// them an alias for loopback. The RFC 6598 shared address space is included
// because, like RFC 1918 space, it is not routable on the public internet.
//
// The list is parsed once at package initialisation; a malformed entry panics
// instead of silently weakening the blocklist.
var privateOrLoopbackCIDRs = func() []*net.IPNet {
	cidrs := []string{
		"0.0.0.0/8",      // "this network" incl. unspecified IPv4 (dials the local host)
		"::/128",         // unspecified IPv6 (dials the local host)
		"127.0.0.0/8",    // loopback IPv4
		"::1/128",        // loopback IPv6
		"169.254.0.0/16", // link-local IPv4 / AWS EC2 instance-metadata
		"fe80::/10",      // link-local IPv6
		"10.0.0.0/8",     // RFC-1918 private
		"172.16.0.0/12",  // RFC-1918 private
		"192.168.0.0/16", // RFC-1918 private
		"100.64.0.0/10",  // RFC-6598 shared address space (carrier-grade NAT)
		"fc00::/7",       // IPv6 ULA
	}
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, n, err := net.ParseCIDR(cidr)
		if err != nil {
			// These are hardcoded compile-time constants; a parse failure
			// indicates a programmer error (e.g. a typo). Panic immediately
			// so the mistake is caught at startup rather than silently
			// weakening the SSRF blocklist.
			panic(fmt.Sprintf("internal error: failed to parse hardcoded CIDR %q: %v", cidr, err))
		}
		nets = append(nets, n)
	}
	return nets
}()
73+
74+
// internalHostnames is the static hostname blocklist for realm URLs. A realm
// whose hostname matches one of these entries is refused outright, no matter
// which IP address the name would resolve to.
var internalHostnames = []string{
	"localhost",
	"host.docker.internal",
	"model-runner.docker.internal",
	"gateway.docker.internal",
}
82+
83+
// isDisallowedIP reports whether ip falls in any of the private/loopback/
84+
// link-local ranges that must not be contacted as a token-exchange realm.
85+
func isDisallowedIP(ip net.IP) bool {
86+
for _, cidr := range privateOrLoopbackCIDRs {
87+
if cidr.Contains(ip) {
88+
return true
89+
}
90+
}
91+
return false
92+
}
93+
94+
// resolveAndValidateRealm parses the realm URL, validates the hostname against
95+
// a static blocklist, and resolves it to a dial address (ip:port) that is safe
96+
// to connect to. By returning the resolved IP, callers can use a custom
97+
// DialContext to connect to that exact address — preventing DNS-rebinding
98+
// attacks where a malicious DNS server could return different IPs for
99+
// successive lookups (TOCTOU).
100+
//
101+
// If hostname is a literal IP address it is validated directly without
102+
// triggering a DNS lookup.
103+
func resolveAndValidateRealm(rawURL string) (dialAddr, hostname string, err error) {
104+
u, err := url.Parse(rawURL)
105+
if err != nil {
106+
return "", "", fmt.Errorf("invalid realm URL: %w", err)
107+
}
108+
109+
hostname = u.Hostname()
110+
port := u.Port()
111+
if port == "" {
112+
switch u.Scheme {
113+
case "https":
114+
port = "443"
115+
default:
116+
port = "80"
117+
}
118+
}
119+
120+
// Block well-known internal hostnames regardless of DNS resolution.
121+
for _, internal := range internalHostnames {
122+
if strings.EqualFold(hostname, internal) {
123+
return "", "", fmt.Errorf("realm URL hostname %q is not allowed", hostname)
124+
}
125+
}
126+
127+
// If the hostname is a literal IP address, validate it directly without a
128+
// DNS lookup — there is no DNS to rebind.
129+
if ip := net.ParseIP(hostname); ip != nil {
130+
if isDisallowedIP(ip) {
131+
return "", "", fmt.Errorf("realm URL contains a disallowed IP address %s", hostname)
132+
}
133+
return net.JoinHostPort(hostname, port), hostname, nil
134+
}
135+
136+
// Resolve the hostname and validate every returned address. Using the
137+
// resolved IP as the dial address prevents DNS rebinding: the same IP that
138+
// passed validation is the one that will be used for the connection.
139+
ips, err := net.LookupHost(hostname)
140+
if err != nil {
141+
return "", "", fmt.Errorf("resolving realm hostname %q: %w", hostname, err)
142+
}
143+
if len(ips) == 0 {
144+
return "", "", fmt.Errorf("realm hostname %q resolved to no addresses", hostname)
145+
}
146+
for _, ipStr := range ips {
147+
ip := net.ParseIP(ipStr)
148+
if ip == nil {
149+
continue
150+
}
151+
if isDisallowedIP(ip) {
152+
return "", "", fmt.Errorf("realm URL resolves to a disallowed address %s", ipStr)
153+
}
154+
}
155+
156+
// All resolved IPs passed validation. Use the first one as the dial
157+
// address so the HTTP client never performs a second DNS lookup.
158+
return net.JoinHostPort(ips[0], port), hostname, nil
159+
}
160+
161+
// buildSafeTransport wraps base with a custom DialContext that connects
162+
// directly to dialAddr (a pre-validated "ip:port") instead of relying on DNS
163+
// resolution. This closes the TOCTOU window between realm URL validation and
164+
// the actual connection. For TLS connections, serverName is set as the SNI
165+
// value so that certificate validation still uses the original hostname.
166+
func buildSafeTransport(base http.RoundTripper, dialAddr, serverName string) (http.RoundTripper, error) {
167+
if base == nil {
168+
base = http.DefaultTransport
169+
}
170+
171+
t, ok := base.(*http.Transport)
172+
if !ok {
173+
return nil, fmt.Errorf("cannot build safe transport from base of type %T; only *http.Transport is supported", base)
174+
}
175+
176+
dial := func(ctx context.Context, network, _ string) (net.Conn, error) {
177+
return (&net.Dialer{}).DialContext(ctx, network, dialAddr)
178+
}
179+
180+
cloned := t.Clone()
181+
cloned.DialContext = dial
182+
if cloned.TLSClientConfig == nil {
183+
cloned.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12} //nolint:gosec
184+
} else {
185+
cloned.TLSClientConfig = cloned.TLSClientConfig.Clone()
186+
}
187+
cloned.TLSClientConfig.ServerName = serverName
188+
return cloned, nil
189+
}
190+
41191
// Ping pings a registry and returns authentication information.
42192
func Ping(ctx context.Context, reg reference.Registry, transport http.RoundTripper) (*PingResponse, error) {
43193
if transport == nil {
@@ -104,12 +254,21 @@ func parseWWWAuthenticate(header string) WWWAuthenticate {
104254
}
105255

106256
// Exchange exchanges credentials for a bearer token.
107-
func Exchange(ctx context.Context, reg reference.Registry, auth authn.Authenticator, transport http.RoundTripper, scopes []string, pr *PingResponse) (*Token, error) {
257+
func Exchange(ctx context.Context, _ reference.Registry, auth authn.Authenticator, transport http.RoundTripper, scopes []string, pr *PingResponse) (*Token, error) {
108258
if transport == nil {
109259
transport = http.DefaultTransport
110260
}
111261

112-
client := &http.Client{Transport: transport}
262+
dialAddr, hostname, err := resolveAndValidateRealm(pr.WWWAuthenticate.Realm)
263+
if err != nil {
264+
return nil, fmt.Errorf("realm URL rejected: %w", err)
265+
}
266+
267+
safeTransport, err := buildSafeTransport(transport, dialAddr, hostname)
268+
if err != nil {
269+
return nil, fmt.Errorf("failed to build safe transport: %w", err)
270+
}
271+
client := &http.Client{Transport: safeTransport}
113272

114273
// Build token request URL
115274
tokenURL, err := url.Parse(pr.WWWAuthenticate.Realm)
@@ -150,7 +309,11 @@ func Exchange(ctx context.Context, reg reference.Registry, auth authn.Authentica
150309

151310
if resp.StatusCode != http.StatusOK {
152311
body, _ := io.ReadAll(resp.Body)
153-
return nil, fmt.Errorf("token request failed with status %d: %s", resp.StatusCode, string(body))
312+
slog.DebugContext(ctx, "token request failed",
313+
"status", resp.StatusCode,
314+
"body", string(body),
315+
)
316+
return nil, fmt.Errorf("token request failed: unexpected status %d from token endpoint", resp.StatusCode)
154317
}
155318

156319
var token Token
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package remote_test
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"strings"
9+
"sync/atomic"
10+
"testing"
11+
12+
"github.com/docker/model-runner/pkg/distribution/oci/reference"
13+
"github.com/docker/model-runner/pkg/distribution/oci/remote"
14+
)
15+
16+
// emptyRegistry is a zero-value reference.Registry.
17+
// reference.Registry is a concrete struct (not an interface), so nil cannot
18+
// be used; the zero value is a safe placeholder when registry-specific
19+
// behaviour is not exercised by the test.
20+
var emptyRegistry reference.Registry
21+
22+
// pingResponseForRealm returns a *remote.PingResponse whose Realm field is
23+
// set to the given URL. This simulates the result of calling Ping() against
24+
// a malicious registry that advertises an attacker-controlled realm.
25+
func pingResponseForRealm(realm string) *remote.PingResponse {
26+
return &remote.PingResponse{
27+
WWWAuthenticate: remote.WWWAuthenticate{
28+
Realm: realm,
29+
Service: "evil-registry",
30+
Scope: "repository:evil/model:pull",
31+
},
32+
}
33+
}
34+
35+
// TestExchangeSSRF_RequestSentToRealmURL verifies that prevents Exchange() from
36+
// contacting internal services via a malicious realm URL.
37+
func TestExchangeSSRF_RequestSentToRealmURL(t *testing.T) {
38+
var hitCount atomic.Int32
39+
40+
// "Internal service" — simulates a host-local endpoint (127.0.0.1) that
41+
// should never be reachable via a registry token-exchange flow.
42+
internalService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43+
hitCount.Add(1)
44+
w.Header().Set("Content-Type", "application/json")
45+
w.WriteHeader(http.StatusForbidden)
46+
fmt.Fprintln(w, `{"error":"not a token endpoint"}`)
47+
}))
48+
defer internalService.Close()
49+
50+
// Build a PingResponse whose Realm points at the internal service —
51+
// this is what a malicious registry would return in its 401 response.
52+
pr := pingResponseForRealm(internalService.URL + "/internal/credentials")
53+
54+
_, err := remote.Exchange(t.Context(), emptyRegistry, nil, nil, []string{"repository:x:pull"}, pr)
55+
if err == nil {
56+
t.Fatal("Exchange() should have returned an error when the realm URL resolves to a loopback address")
57+
}
58+
if hitCount.Load() > 0 {
59+
t.Errorf("SSRF not blocked: Exchange() sent %d request(s) to the internal service at %s — realm URL validation should have rejected 127.0.0.1 before making any HTTP request", hitCount.Load(), internalService.URL)
60+
}
61+
if !strings.Contains(err.Error(), "realm URL rejected") {
62+
t.Errorf("expected error to mention realm URL rejection, got: %q", err.Error())
63+
}
64+
}
65+
66+
// TestExchangeSSRF_SensitiveBodyNotReflectedInError verifies that Exchange()
67+
// does NOT include a token-endpoint response body in the error it returns to
68+
// the caller.
69+
func TestExchangeSSRF_SensitiveBodyNotReflectedInError(t *testing.T) {
70+
sensitiveData := map[string]string{
71+
"db_password": "s3cret_example",
72+
"internal_api_key": "sk-example-key",
73+
"proof": "THIS_DATA_READ_VIA_SSRF",
74+
}
75+
sensitiveJSON, err := json.Marshal(sensitiveData)
76+
if err != nil {
77+
t.Fatalf("failed to marshal test data: %v", err)
78+
}
79+
80+
// "Internal service" returns a non-200 response containing sensitive data.
81+
internalService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82+
w.Header().Set("Content-Type", "application/json")
83+
w.WriteHeader(http.StatusUnauthorized)
84+
w.Write(sensitiveJSON) //nolint:errcheck
85+
}))
86+
defer internalService.Close()
87+
88+
pr := pingResponseForRealm(internalService.URL + "/admin/api/keys")
89+
90+
_, err = remote.Exchange(t.Context(), emptyRegistry, nil, nil, []string{"repository:x:pull"}, pr)
91+
if err == nil {
92+
t.Fatal("expected an error from Exchange() when realm URL is blocked or token endpoint returns 401, got nil")
93+
}
94+
95+
errMsg := err.Error()
96+
97+
// The error must indicate realm rejection, not a remote HTTP status code.
98+
if !strings.Contains(errMsg, "realm URL rejected") {
99+
t.Errorf("expected error to mention realm URL rejection (Fix 2a), got: %q", errMsg)
100+
}
101+
102+
// the response body must never appear in the returned error
103+
for _, sensitive := range []string{"s3cret_example", "sk-example-key", "THIS_DATA_READ_VIA_SSRF"} {
104+
if strings.Contains(errMsg, sensitive) {
105+
t.Errorf("BODY REFLECTION vulnerability: error contains sensitive value %q — the response body must not be included in errors returned to the caller (log at debug level instead)", sensitive)
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)