@@ -2,9 +2,12 @@ package remote
22
33import (
44 "context"
5+ "crypto/tls"
56 "encoding/json"
67 "fmt"
78 "io"
9+ "log/slog"
10+ "net"
811 "net/http"
912 "net/url"
1013 "strings"
@@ -38,6 +41,153 @@ type Token struct {
3841 ExpiresIn int `json:"expires_in"`
3942}
4043
44+ // privateOrLoopbackCIDRs lists IP ranges that must never be contacted as a
45+ // token-exchange realm. Allowing requests to these addresses would let a
46+ // malicious registry pivot Model Runner into an internal-service proxy
47+ // (SSRF), reaching endpoints that are not accessible from the public internet.
48+ var privateOrLoopbackCIDRs = func () []* net.IPNet {
49+ cidrs := []string {
50+ "127.0.0.0/8" , // loopback IPv4
51+ "::1/128" , // loopback IPv6
52+ "169.254.0.0/16" , // link-local IPv4 / AWS EC2 instance-metadata
53+ "fe80::/10" , // link-local IPv6
54+ "10.0.0.0/8" , // RFC-1918 private
55+ "172.16.0.0/12" , // RFC-1918 private
56+ "192.168.0.0/16" , // RFC-1918 private
57+ "fc00::/7" , // IPv6 ULA
58+ }
59+ nets := make ([]* net.IPNet , 0 , len (cidrs ))
60+ for _ , cidr := range cidrs {
61+ _ , n , err := net .ParseCIDR (cidr )
62+ if err != nil {
63+ // These are hardcoded compile-time constants; a parse failure
64+ // indicates a programmer error (e.g. a typo). Panic immediately
65+ // so the mistake is caught at startup rather than silently
66+ // weakening the SSRF blocklist.
67+ panic (fmt .Sprintf ("internal error: failed to parse hardcoded CIDR %q: %v" , cidr , err ))
68+ }
69+ nets = append (nets , n )
70+ }
71+ return nets
72+ }()
73+
74+ // internalHostnames lists hostnames that must never be used as a realm,
75+ // regardless of what IP address they resolve to.
76+ var internalHostnames = []string {
77+ "localhost" ,
78+ "host.docker.internal" ,
79+ "model-runner.docker.internal" ,
80+ "gateway.docker.internal" ,
81+ }
82+
83+ // isDisallowedIP reports whether ip falls in any of the private/loopback/
84+ // link-local ranges that must not be contacted as a token-exchange realm.
85+ func isDisallowedIP (ip net.IP ) bool {
86+ for _ , cidr := range privateOrLoopbackCIDRs {
87+ if cidr .Contains (ip ) {
88+ return true
89+ }
90+ }
91+ return false
92+ }
93+
94+ // resolveAndValidateRealm parses the realm URL, validates the hostname against
95+ // a static blocklist, and resolves it to a dial address (ip:port) that is safe
96+ // to connect to. By returning the resolved IP, callers can use a custom
97+ // DialContext to connect to that exact address — preventing DNS-rebinding
98+ // attacks where a malicious DNS server could return different IPs for
99+ // successive lookups (TOCTOU).
100+ //
101+ // If hostname is a literal IP address it is validated directly without
102+ // triggering a DNS lookup.
103+ func resolveAndValidateRealm (rawURL string ) (dialAddr , hostname string , err error ) {
104+ u , err := url .Parse (rawURL )
105+ if err != nil {
106+ return "" , "" , fmt .Errorf ("invalid realm URL: %w" , err )
107+ }
108+
109+ hostname = u .Hostname ()
110+ port := u .Port ()
111+ if port == "" {
112+ switch u .Scheme {
113+ case "https" :
114+ port = "443"
115+ default :
116+ port = "80"
117+ }
118+ }
119+
120+ // Block well-known internal hostnames regardless of DNS resolution.
121+ for _ , internal := range internalHostnames {
122+ if strings .EqualFold (hostname , internal ) {
123+ return "" , "" , fmt .Errorf ("realm URL hostname %q is not allowed" , hostname )
124+ }
125+ }
126+
127+ // If the hostname is a literal IP address, validate it directly without a
128+ // DNS lookup — there is no DNS to rebind.
129+ if ip := net .ParseIP (hostname ); ip != nil {
130+ if isDisallowedIP (ip ) {
131+ return "" , "" , fmt .Errorf ("realm URL contains a disallowed IP address %s" , hostname )
132+ }
133+ return net .JoinHostPort (hostname , port ), hostname , nil
134+ }
135+
136+ // Resolve the hostname and validate every returned address. Using the
137+ // resolved IP as the dial address prevents DNS rebinding: the same IP that
138+ // passed validation is the one that will be used for the connection.
139+ ips , err := net .LookupHost (hostname )
140+ if err != nil {
141+ return "" , "" , fmt .Errorf ("resolving realm hostname %q: %w" , hostname , err )
142+ }
143+ if len (ips ) == 0 {
144+ return "" , "" , fmt .Errorf ("realm hostname %q resolved to no addresses" , hostname )
145+ }
146+ for _ , ipStr := range ips {
147+ ip := net .ParseIP (ipStr )
148+ if ip == nil {
149+ continue
150+ }
151+ if isDisallowedIP (ip ) {
152+ return "" , "" , fmt .Errorf ("realm URL resolves to a disallowed address %s" , ipStr )
153+ }
154+ }
155+
156+ // All resolved IPs passed validation. Use the first one as the dial
157+ // address so the HTTP client never performs a second DNS lookup.
158+ return net .JoinHostPort (ips [0 ], port ), hostname , nil
159+ }
160+
161+ // buildSafeTransport wraps base with a custom DialContext that connects
162+ // directly to dialAddr (a pre-validated "ip:port") instead of relying on DNS
163+ // resolution. This closes the TOCTOU window between realm URL validation and
164+ // the actual connection. For TLS connections, serverName is set as the SNI
165+ // value so that certificate validation still uses the original hostname.
166+ func buildSafeTransport (base http.RoundTripper , dialAddr , serverName string ) (http.RoundTripper , error ) {
167+ if base == nil {
168+ base = http .DefaultTransport
169+ }
170+
171+ t , ok := base .(* http.Transport )
172+ if ! ok {
173+ return nil , fmt .Errorf ("cannot build safe transport from base of type %T; only *http.Transport is supported" , base )
174+ }
175+
176+ dial := func (ctx context.Context , network , _ string ) (net.Conn , error ) {
177+ return (& net.Dialer {}).DialContext (ctx , network , dialAddr )
178+ }
179+
180+ cloned := t .Clone ()
181+ cloned .DialContext = dial
182+ if cloned .TLSClientConfig == nil {
183+ cloned .TLSClientConfig = & tls.Config {MinVersion : tls .VersionTLS12 } //nolint:gosec
184+ } else {
185+ cloned .TLSClientConfig = cloned .TLSClientConfig .Clone ()
186+ }
187+ cloned .TLSClientConfig .ServerName = serverName
188+ return cloned , nil
189+ }
190+
41191// Ping pings a registry and returns authentication information.
42192func Ping (ctx context.Context , reg reference.Registry , transport http.RoundTripper ) (* PingResponse , error ) {
43193 if transport == nil {
@@ -104,12 +254,21 @@ func parseWWWAuthenticate(header string) WWWAuthenticate {
104254}
105255
106256// Exchange exchanges credentials for a bearer token.
107- func Exchange (ctx context.Context , reg reference.Registry , auth authn.Authenticator , transport http.RoundTripper , scopes []string , pr * PingResponse ) (* Token , error ) {
257+ func Exchange (ctx context.Context , _ reference.Registry , auth authn.Authenticator , transport http.RoundTripper , scopes []string , pr * PingResponse ) (* Token , error ) {
108258 if transport == nil {
109259 transport = http .DefaultTransport
110260 }
111261
112- client := & http.Client {Transport : transport }
262+ dialAddr , hostname , err := resolveAndValidateRealm (pr .WWWAuthenticate .Realm )
263+ if err != nil {
264+ return nil , fmt .Errorf ("realm URL rejected: %w" , err )
265+ }
266+
267+ safeTransport , err := buildSafeTransport (transport , dialAddr , hostname )
268+ if err != nil {
269+ return nil , fmt .Errorf ("failed to build safe transport: %w" , err )
270+ }
271+ client := & http.Client {Transport : safeTransport }
113272
114273 // Build token request URL
115274 tokenURL , err := url .Parse (pr .WWWAuthenticate .Realm )
@@ -150,7 +309,11 @@ func Exchange(ctx context.Context, reg reference.Registry, auth authn.Authentica
150309
151310 if resp .StatusCode != http .StatusOK {
152311 body , _ := io .ReadAll (resp .Body )
153- return nil , fmt .Errorf ("token request failed with status %d: %s" , resp .StatusCode , string (body ))
312+ slog .DebugContext (ctx , "token request failed" ,
313+ "status" , resp .StatusCode ,
314+ "body" , string (body ),
315+ )
316+ return nil , fmt .Errorf ("token request failed: unexpected status %d from token endpoint" , resp .StatusCode )
154317 }
155318
156319 var token Token
0 commit comments