@@ -28,6 +28,7 @@ import (
2828 "strings"
2929 "sync"
3030 "sync/atomic"
31+ "syscall"
3132 "time"
3233
3334 "github.com/matrix-org/gomatrix"
@@ -54,13 +55,15 @@ type UserInfo struct {
5455}
5556
5657type clientOptions struct {
57- transport http.RoundTripper
58- dnsCache * DNSCache
59- timeout time.Duration
60- skipVerify bool
61- keepAlives bool
62- wellKnownSRV bool
63- userAgent string
58+ transport http.RoundTripper
59+ dnsCache * DNSCache
60+ timeout time.Duration
61+ skipVerify bool
62+ keepAlives bool
63+ wellKnownSRV bool
64+ userAgent string
65+ allowNetworks []string
66+ denyNetworks []string
6467}
6568
6669// ClientOption are supplied to NewClient or NewFederationClient.
@@ -82,6 +85,8 @@ func NewClient(options ...ClientOption) *Client {
8285 clientOpts .dnsCache ,
8386 clientOpts .keepAlives ,
8487 clientOpts .wellKnownSRV ,
88+ clientOpts .allowNetworks ,
89+ clientOpts .denyNetworks ,
8590 )
8691 }
8792 client := & Client {
@@ -152,6 +157,15 @@ func WithUserAgent(userAgent string) ClientOption {
152157 }
153158}
154159
160+ // WithAllowDenyNetworks sets the allowed and denied networks for the http client. By default,
161+ // all networks are allowed. The deny list is checked before the allow list.
162+ func WithAllowDenyNetworks (allowCIDRs []string , denyCIDRs []string ) ClientOption {
163+ return func (options * clientOptions ) {
164+ options .allowNetworks = allowCIDRs
165+ options .denyNetworks = denyCIDRs
166+ }
167+ }
168+
155169const destinationTripperLifetime = time .Minute * 5 // how long to keep an entry
156170const destinationTripperReapInterval = time .Minute // how often to check for dead entries
157171
@@ -165,15 +179,17 @@ type destinationTripper struct {
165179 dnsCache * DNSCache
166180 keepAlives bool
167181 wellKnownSRV bool
182+ dialer * net.Dialer
168183}
169184
170- func newDestinationTripper (skipVerify bool , dnsCache * DNSCache , keepAlives , wellKnownSRV bool ) * destinationTripper {
185+ func newDestinationTripper (skipVerify bool , dnsCache * DNSCache , keepAlives , wellKnownSRV bool , allowCIDRs [] string , denyCIDRs [] string ) * destinationTripper {
171186 tripper := & destinationTripper {
172187 transports : make (map [string ]* destinationTripperTransport ),
173188 skipVerify : skipVerify ,
174189 dnsCache : dnsCache ,
175190 keepAlives : keepAlives ,
176191 wellKnownSRV : wellKnownSRV ,
192+ dialer : newDestinationTripperDialer (allowCIDRs , denyCIDRs ),
177193 }
178194 time .AfterFunc (destinationTripperReapInterval , tripper .reaper )
179195 return tripper
@@ -195,11 +211,71 @@ func (f *destinationTripper) reaper() {
195211 time .AfterFunc (destinationTripperReapInterval , f .reaper )
196212}
197213
198- // destinationTripperDialer enforces dial timeouts on the federation requests. If
214+ // newDestinationTripperDialer creates a dialer which enforces dial timeouts on the federation requests. If
199215// the TCP connection doesn't complete within 5 seconds, it's probably just not
200216// going to.
201- var destinationTripperDialer = & net.Dialer {
202- Timeout : time .Second * 5 ,
217+ // The dialer can also be limited to CIDR ranges, if allow or deny networks is non-empty.
218+ func newDestinationTripperDialer (allowNetworks []string , denyNetworks []string ) * net.Dialer {
219+ if len (allowNetworks ) == 0 && len (denyNetworks ) == 0 {
220+ return & net.Dialer {
221+ Timeout : time .Second * 5 ,
222+ }
223+ }
224+
225+ return & net.Dialer {
226+ Timeout : time .Second * 5 ,
227+ ControlContext : allowDenyNetworksControl (allowNetworks , denyNetworks ),
228+ }
229+ }
230+
231+ // allowDenyNetworksControl is used to allow/deny access to certain networks
232+ func allowDenyNetworksControl (allowNetworks , denyNetworks []string ) func (_ context.Context , network string , address string , conn syscall.RawConn ) error {
233+ return func (_ context.Context , network string , address string , conn syscall.RawConn ) error {
234+ if network != "tcp4" && network != "tcp6" {
235+ return fmt .Errorf ("%s is not a safe network type" , network )
236+ }
237+
238+ host , _ , err := net .SplitHostPort (address )
239+ if err != nil {
240+ return fmt .Errorf ("%s is not a valid host/port pair: %s" , address , err )
241+ }
242+
243+ ipaddress := net .ParseIP (host )
244+ if ipaddress == nil {
245+ return fmt .Errorf ("%s is not a valid IP address" , host )
246+ }
247+
248+ if ! isAllowed (ipaddress , allowNetworks , denyNetworks ) {
249+ return fmt .Errorf ("%s is denied" , address )
250+ }
251+
252+ return nil // allow connection
253+ }
254+ }
255+
256+ func isAllowed (ip net.IP , allowCIDRs []string , denyCIDRs []string ) bool {
257+ if inRange (ip , denyCIDRs ) {
258+ return false
259+ }
260+ if inRange (ip , allowCIDRs ) {
261+ return true
262+ }
263+ return false // "should never happen"
264+ }
265+
266+ func inRange (ip net.IP , CIDRs []string ) bool {
267+ for i := 0 ; i < len (CIDRs ); i ++ {
268+ cidr := CIDRs [i ]
269+ _ , network , err := net .ParseCIDR (cidr )
270+ if err != nil {
271+ return false
272+ }
273+ if network .Contains (ip ) {
274+ return true
275+ }
276+ }
277+
278+ return false
203279}
204280
205281type destinationTripperTransport struct {
@@ -213,7 +289,7 @@ type destinationTripperTransport struct {
213289// We need to use one transport per TLS server name (instead of giving our round
214290// tripper a single transport) because there is no way to specify the TLS
215291// ServerName on a per-connection basis.
216- func (f * destinationTripper ) getTransport (tlsServerName string ) http.RoundTripper {
292+ func (f * destinationTripper ) getTransport (tlsServerName string , dialer * net. Dialer ) http.RoundTripper {
217293 f .transportsMutex .Lock ()
218294 defer f .transportsMutex .Unlock ()
219295
@@ -230,8 +306,8 @@ func (f *destinationTripper) getTransport(tlsServerName string) http.RoundTrippe
230306 InsecureSkipVerify : f .skipVerify ,
231307 ClientSessionCache : tls .NewLRUClientSessionCache (0 ), // 0 = use default
232308 },
233- Dial : destinationTripperDialer .Dial , // nolint: staticcheck
234- DialContext : destinationTripperDialer .DialContext ,
309+ Dial : dialer .Dial , // nolint: staticcheck
310+ DialContext : dialer .DialContext ,
235311 Proxy : http .ProxyFromEnvironment ,
236312 ForceAttemptHTTP2 : true , // if we can multiplex requests over HTTP/2, we should
237313 },
@@ -296,7 +372,7 @@ retryResolution:
296372 u := makeHTTPSURL (r .URL , result .Destination )
297373 r .URL = & u
298374 r .Host = string (result .Host )
299- resp , err = f .getTransport (result .TLSServerName ).RoundTrip (r )
375+ resp , err = f .getTransport (result .TLSServerName , f . dialer ).RoundTrip (r )
300376 if err == nil {
301377 return resp , nil
302378 }
0 commit comments