Skip to content

Commit 3bb52e7

Browse files
committed
Extend TLS options
* create proper tls config for https proxy dialer * refactor common tls routines into tlsutil * allow mTLS for upstream proxies * customizable curve list
1 parent 1d8369d commit 3bb52e7

File tree

3 files changed

+276
-110
lines changed

3 files changed

+276
-110
lines changed

dialer/upstream.go

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,105 @@ import (
1616
"sync"
1717

1818
xproxy "golang.org/x/net/proxy"
19+
20+
"github.com/SenseUnit/dumbproxy/tlsutil"
1921
)
2022

2123
type HTTPProxyDialer struct {
22-
address string
23-
tls bool
24-
userinfo *url.Userinfo
25-
next Dialer
24+
address string
25+
tlsConfig *tls.Config
26+
userinfo *url.Userinfo
27+
next Dialer
2628
}
2729

28-
func NewHTTPProxyDialer(address string, tls bool, userinfo *url.Userinfo, next LegacyDialer) *HTTPProxyDialer {
30+
func NewHTTPProxyDialer(address string, tlsConfig *tls.Config, userinfo *url.Userinfo, next LegacyDialer) *HTTPProxyDialer {
2931
return &HTTPProxyDialer{
30-
address: address,
31-
tls: tls,
32-
next: MaybeWrapWithContextDialer(next),
33-
userinfo: userinfo,
32+
address: address,
33+
tlsConfig: tlsConfig,
34+
next: MaybeWrapWithContextDialer(next),
35+
userinfo: userinfo,
3436
}
3537
}
3638

3739
func HTTPProxyDialerFromURL(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
3840
host := u.Hostname()
3941
port := u.Port()
40-
tls := false
42+
params, err := url.ParseQuery(u.RawQuery)
43+
if err != nil {
44+
return nil, fmt.Errorf("unable to parse query string of proxy specification URL %q: %w", u.String(), err)
45+
}
4146

47+
var tlsConfig *tls.Config
4248
switch strings.ToLower(u.Scheme) {
4349
case "http":
4450
if port == "" {
4551
port = "80"
4652
}
4753
case "https":
48-
tls = true
4954
if port == "" {
5055
port = "443"
5156
}
57+
tlsConfig = &tls.Config{
58+
ServerName: host,
59+
}
60+
if params.Has("cafile") {
61+
roots, err := tlsutil.LoadCAfile(params.Get("cafile"))
62+
if err != nil {
63+
return nil, err
64+
}
65+
tlsConfig.RootCAs = roots
66+
}
67+
if params.Has("sni") {
68+
tlsConfig.ServerName = params.Get("sni")
69+
tlsConfig.InsecureSkipVerify = true
70+
tlsConfig.VerifyConnection = tlsutil.ExpectPeerName(host, tlsConfig.RootCAs)
71+
}
72+
if params.Has("peername") {
73+
tlsConfig.InsecureSkipVerify = true
74+
tlsConfig.VerifyConnection = tlsutil.ExpectPeerName(params.Get("peername"), tlsConfig.RootCAs)
75+
}
76+
if params.Has("cert") {
77+
cert, err := tls.LoadX509KeyPair(params.Get("cert"), params.Get("key"))
78+
if err != nil {
79+
return nil, err
80+
}
81+
tlsConfig.Certificates = []tls.Certificate{cert}
82+
}
83+
if params.Has("ciphers") {
84+
cipherList, err := tlsutil.ParseCipherList(params.Get("ciphers"))
85+
if err != nil {
86+
return nil, err
87+
}
88+
tlsConfig.CipherSuites = cipherList
89+
}
90+
if params.Has("curves") {
91+
curveList, err := tlsutil.ParseCurveList(params.Get("curves"))
92+
if err != nil {
93+
return nil, err
94+
}
95+
tlsConfig.CurvePreferences = curveList
96+
}
97+
if params.Has("min-tls-version") {
98+
ver, err := tlsutil.ParseVersion(params.Get("min-tls-version"))
99+
if err != nil {
100+
return nil, err
101+
}
102+
tlsConfig.MinVersion = ver
103+
}
104+
if params.Has("max-tls-version") {
105+
ver, err := tlsutil.ParseVersion(params.Get("max-tls-version"))
106+
if err != nil {
107+
return nil, err
108+
}
109+
tlsConfig.MaxVersion = ver
110+
}
52111
default:
53112
return nil, errors.New("unsupported proxy type")
54113
}
55114

56115
address := net.JoinHostPort(host, port)
57116

58-
return NewHTTPProxyDialer(address, tls, u.User, next), nil
117+
return NewHTTPProxyDialer(address, tlsConfig, u.User, next), nil
59118
}
60119

61120
func (d *HTTPProxyDialer) Dial(network, address string) (net.Conn, error) {
@@ -72,14 +131,8 @@ func (d *HTTPProxyDialer) DialContext(ctx context.Context, network, address stri
72131
if err != nil {
73132
return nil, fmt.Errorf("proxy dialer is unable to make connection: %w", err)
74133
}
75-
if d.tls {
76-
hostname, _, err := net.SplitHostPort(d.address)
77-
if err != nil {
78-
hostname = address
79-
}
80-
conn = tls.Client(conn, &tls.Config{
81-
ServerName: hostname,
82-
})
134+
if d.tlsConfig != nil {
135+
conn = tls.Client(conn, d.tlsConfig)
83136
}
84137

85138
stopGuardEvent := make(chan struct{})

main.go

Lines changed: 43 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@ import (
44
"bytes"
55
"crypto/rand"
66
"crypto/tls"
7-
"crypto/x509"
87
"encoding/base64"
98
"encoding/binary"
109
"encoding/hex"
1110
"errors"
1211
"flag"
1312
"fmt"
14-
"io/ioutil"
1513
"log"
1614
"net"
1715
"net/http"
@@ -37,6 +35,7 @@ import (
3735
"github.com/SenseUnit/dumbproxy/forward"
3836
"github.com/SenseUnit/dumbproxy/handler"
3937
clog "github.com/SenseUnit/dumbproxy/log"
38+
"github.com/SenseUnit/dumbproxy/tlsutil"
4039
proxyproto "github.com/pires/go-proxyproto"
4140

4241
_ "golang.org/x/crypto/x509roots/fallback"
@@ -118,61 +117,16 @@ func (l *PrefixList) Value() []netip.Prefix {
118117
type TLSVersionArg uint16
119118

120119
func (a *TLSVersionArg) Set(s string) error {
121-
var ver uint16
122-
switch strings.ToUpper(s) {
123-
case "TLS10":
124-
ver = tls.VersionTLS10
125-
case "TLS11":
126-
ver = tls.VersionTLS11
127-
case "TLS12":
128-
ver = tls.VersionTLS12
129-
case "TLS13":
130-
ver = tls.VersionTLS13
131-
case "TLS1.0":
132-
ver = tls.VersionTLS10
133-
case "TLS1.1":
134-
ver = tls.VersionTLS11
135-
case "TLS1.2":
136-
ver = tls.VersionTLS12
137-
case "TLS1.3":
138-
ver = tls.VersionTLS13
139-
case "10":
140-
ver = tls.VersionTLS10
141-
case "11":
142-
ver = tls.VersionTLS11
143-
case "12":
144-
ver = tls.VersionTLS12
145-
case "13":
146-
ver = tls.VersionTLS13
147-
case "1.0":
148-
ver = tls.VersionTLS10
149-
case "1.1":
150-
ver = tls.VersionTLS11
151-
case "1.2":
152-
ver = tls.VersionTLS12
153-
case "1.3":
154-
ver = tls.VersionTLS13
155-
case "":
156-
default:
157-
return fmt.Errorf("unknown TLS version %q", s)
120+
ver, err := tlsutil.ParseVersion(s)
121+
if err != nil {
122+
return err
158123
}
159124
*a = TLSVersionArg(ver)
160125
return nil
161126
}
162127

163128
func (a *TLSVersionArg) String() string {
164-
switch *a {
165-
case tls.VersionTLS10:
166-
return "TLS10"
167-
case tls.VersionTLS11:
168-
return "TLS11"
169-
case tls.VersionTLS12:
170-
return "TLS12"
171-
case tls.VersionTLS13:
172-
return "TLS13"
173-
default:
174-
return fmt.Sprintf("%#04x", *a)
175-
}
129+
return tlsutil.FormatVersion(uint16(*a))
176130
}
177131

178132
type proxyArg struct {
@@ -224,7 +178,9 @@ type CLIArgs struct {
224178
verbosity int
225179
cert, key, cafile string
226180
list_ciphers bool
181+
list_curves bool
227182
ciphers string
183+
curves string
228184
disableHTTP2 bool
229185
showVersion bool
230186
autocert bool
@@ -292,7 +248,9 @@ func parse_args() CLIArgs {
292248
flag.StringVar(&args.key, "key", "", "key for TLS certificate")
293249
flag.StringVar(&args.cafile, "cafile", "", "CA file to authenticate clients with certificates")
294250
flag.BoolVar(&args.list_ciphers, "list-ciphers", false, "list ciphersuites")
251+
flag.BoolVar(&args.list_curves, "list-curves", false, "list key exchange curves")
295252
flag.StringVar(&args.ciphers, "ciphers", "", "colon-separated list of enabled ciphers")
253+
flag.StringVar(&args.curves, "curves", "", "colon-separated list of enabled key exchange curves")
296254
flag.BoolVar(&args.disableHTTP2, "disable-http2", false, "disable HTTP2")
297255
flag.BoolVar(&args.showVersion, "version", false, "show program version and exit")
298256
flag.BoolVar(&args.autocert, "autocert", false, "issue TLS certificates automatically")
@@ -374,6 +332,11 @@ func run() int {
374332
return 0
375333
}
376334

335+
if args.list_curves {
336+
list_curves()
337+
return 0
338+
}
339+
377340
if args.passwd != "" {
378341
if err := passwd(args.passwd, args.passwdCost, args.positionalArgs...); err != nil {
379342
log.Fatalf("can't set password: %v", err)
@@ -570,7 +533,8 @@ func run() int {
570533

571534
if args.cert != "" {
572535
cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile,
573-
args.ciphers, uint16(args.minTLSVersion), uint16(args.maxTLSVersion), !args.disableHTTP2)
536+
args.ciphers, args.curves,
537+
uint16(args.minTLSVersion), uint16(args.maxTLSVersion), !args.disableHTTP2)
574538
if err1 != nil {
575539
mainLogger.Critical("TLS config construction failed: %v", err1)
576540
return 3
@@ -629,7 +593,7 @@ func run() int {
629593
}()
630594
}
631595
cfg := m.TLSConfig()
632-
cfg, err = updateServerTLSConfig(cfg, args.cafile, args.ciphers,
596+
cfg, err = updateServerTLSConfig(cfg, args.cafile, args.ciphers, args.curves,
633597
uint16(args.minTLSVersion), uint16(args.maxTLSVersion), !args.disableHTTP2)
634598
if err != nil {
635599
mainLogger.Critical("TLS config construction failed: %v", err)
@@ -657,7 +621,7 @@ func run() int {
657621
return 0
658622
}
659623

660-
func makeServerTLSConfig(certfile, keyfile, cafile, ciphers string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) {
624+
func makeServerTLSConfig(certfile, keyfile, cafile, ciphers, curves string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) {
661625
cfg := tls.Config{
662626
MinVersion: minVer,
663627
MaxVersion: maxVer,
@@ -668,18 +632,21 @@ func makeServerTLSConfig(certfile, keyfile, cafile, ciphers string, minVer, maxV
668632
}
669633
cfg.Certificates = []tls.Certificate{cert}
670634
if cafile != "" {
671-
roots := x509.NewCertPool()
672-
certs, err := ioutil.ReadFile(cafile)
635+
roots, err := tlsutil.LoadCAfile(cafile)
673636
if err != nil {
674637
return nil, err
675638
}
676-
if ok := roots.AppendCertsFromPEM(certs); !ok {
677-
return nil, errors.New("Failed to load CA certificates")
678-
}
679639
cfg.ClientCAs = roots
680640
cfg.ClientAuth = tls.VerifyClientCertIfGiven
681641
}
682-
cfg.CipherSuites = makeCipherList(ciphers)
642+
cfg.CipherSuites, err = tlsutil.ParseCipherList(ciphers)
643+
if err != nil {
644+
return nil, err
645+
}
646+
cfg.CurvePreferences, err = tlsutil.ParseCurveList(curves)
647+
if err != nil {
648+
return nil, err
649+
}
683650
if h2 {
684651
cfg.NextProtos = []string{"h2", "http/1.1"}
685652
} else {
@@ -688,20 +655,24 @@ func makeServerTLSConfig(certfile, keyfile, cafile, ciphers string, minVer, maxV
688655
return &cfg, nil
689656
}
690657

691-
func updateServerTLSConfig(cfg *tls.Config, cafile, ciphers string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) {
658+
func updateServerTLSConfig(cfg *tls.Config, cafile, ciphers, curves string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) {
692659
if cafile != "" {
693-
roots := x509.NewCertPool()
694-
certs, err := ioutil.ReadFile(cafile)
660+
roots, err := tlsutil.LoadCAfile(cafile)
695661
if err != nil {
696662
return nil, err
697663
}
698-
if ok := roots.AppendCertsFromPEM(certs); !ok {
699-
return nil, errors.New("Failed to load CA certificates")
700-
}
701664
cfg.ClientCAs = roots
702665
cfg.ClientAuth = tls.VerifyClientCertIfGiven
703666
}
704-
cfg.CipherSuites = makeCipherList(ciphers)
667+
var err error
668+
cfg.CipherSuites, err = tlsutil.ParseCipherList(ciphers)
669+
if err != nil {
670+
return nil, err
671+
}
672+
cfg.CurvePreferences, err = tlsutil.ParseCurveList(curves)
673+
if err != nil {
674+
return nil, err
675+
}
705676
if h2 {
706677
cfg.NextProtos = []string{"h2", "http/1.1", "acme-tls/1"}
707678
} else {
@@ -712,33 +683,15 @@ func updateServerTLSConfig(cfg *tls.Config, cafile, ciphers string, minVer, maxV
712683
return cfg, nil
713684
}
714685

715-
func makeCipherList(ciphers string) []uint16 {
716-
if ciphers == "" {
717-
return nil
718-
}
719-
720-
cipherIDs := make(map[string]uint16)
686+
func list_ciphers() {
721687
for _, cipher := range tls.CipherSuites() {
722-
cipherIDs[cipher.Name] = cipher.ID
723-
}
724-
725-
cipherNameList := strings.Split(ciphers, ":")
726-
cipherIDList := make([]uint16, 0, len(cipherNameList))
727-
728-
for _, name := range cipherNameList {
729-
id, ok := cipherIDs[name]
730-
if !ok {
731-
log.Printf("WARNING: Unknown cipher \"%s\"", name)
732-
}
733-
cipherIDList = append(cipherIDList, id)
688+
fmt.Println(cipher.Name)
734689
}
735-
736-
return cipherIDList
737690
}
738691

739-
func list_ciphers() {
740-
for _, cipher := range tls.CipherSuites() {
741-
fmt.Println(cipher.Name)
692+
func list_curves() {
693+
for _, curve := range tlsutil.Curves() {
694+
fmt.Println(curve.String())
742695
}
743696
}
744697

0 commit comments

Comments
 (0)