Skip to content

Commit 6ec158f

Browse files
authored
Merge pull request #151 from SenseUnit/cust_alpn
ALPN option
2 parents 4c13017 + a3fd5ff commit 6ec158f

File tree

2 files changed

+60
-64
lines changed

2 files changed

+60
-64
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ Usage of /home/user/go/bin/dumbproxy:
553553
amount of time allowed to read request headers (default 30s)
554554
-shutdown-timeout duration
555555
grace period during server shutdown (default 1s)
556+
-tls-alpn-enabled
557+
enable application protocol negotiation with TLS ALPN extension (default true)
556558
-unix-sock-mode value
557559
set file mode for bound unix socket
558560
-unix-sock-unlink

main.go

Lines changed: 58 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,40 @@ func arg_fail(msg string) {
6767
os.Exit(2)
6868
}
6969

70-
type CSVArg []string
71-
72-
func (a *CSVArg) Set(s string) error {
73-
*a = strings.Split(s, ",")
74-
return nil
70+
type CSVArg struct {
71+
values []string
7572
}
7673

7774
func (a *CSVArg) String() string {
78-
if a == nil {
79-
return "<nil>"
80-
}
81-
if *a == nil {
82-
return "<empty>"
75+
if len(a.values) == 0 {
76+
return ""
8377
}
84-
return strings.Join(*a, ",")
78+
buf := new(bytes.Buffer)
79+
wr := csv.NewWriter(buf)
80+
wr.Write(a.values)
81+
wr.Flush()
82+
return strings.TrimRight(buf.String(), "\n")
8583
}
8684

87-
func (a *CSVArg) Value() []string {
88-
return []string(*a)
85+
func (a *CSVArg) Set(line string) error {
86+
if line == "" {
87+
a.values = nil
88+
return nil
89+
}
90+
rd := csv.NewReader(strings.NewReader(line))
91+
rd.FieldsPerRecord = -1
92+
rd.TrimLeadingSpace = true
93+
rd.ReuseRecord = true
94+
values, err := rd.Read()
95+
if err == io.EOF {
96+
a.values = nil
97+
return nil
98+
}
99+
if err != nil {
100+
return fmt.Errorf("unable to parse comma-separated argument: %w", err)
101+
}
102+
a.values = values
103+
return nil
89104
}
90105

91106
type PrefixList []netip.Prefix
@@ -295,6 +310,7 @@ type CLIArgs struct {
295310
userIPHints bool
296311
minTLSVersion TLSVersionArg
297312
maxTLSVersion TLSVersionArg
313+
tlsALPNEnabled bool
298314
bwLimit uint64
299315
bwBurst int64
300316
bwBuckets uint
@@ -313,8 +329,8 @@ type CLIArgs struct {
313329
shutdownTimeout time.Duration
314330
}
315331

316-
func parse_args() CLIArgs {
317-
args := CLIArgs{
332+
func parse_args() *CLIArgs {
333+
args := &CLIArgs{
318334
minTLSVersion: TLSVersionArg(tls.VersionTLS12),
319335
maxTLSVersion: TLSVersionArg(tls.VersionTLS13),
320336
denyDstAddr: PrefixList{
@@ -420,6 +436,7 @@ func parse_args() CLIArgs {
420436
flag.BoolVar(&args.userIPHints, "user-ip-hints", false, "allow IP hints to be specified by user in X-Src-IP-Hints header")
421437
flag.Var(&args.minTLSVersion, "min-tls-version", "minimum TLS version accepted by server")
422438
flag.Var(&args.maxTLSVersion, "max-tls-version", "maximum TLS version accepted by server")
439+
flag.BoolVar(&args.tlsALPNEnabled, "tls-alpn-enabled", true, "enable application protocol negotiation with TLS ALPN extension")
423440
flag.Uint64Var(&args.bwLimit, "bw-limit", 0, "per-user bandwidth limit in bytes per second")
424441
flag.Int64Var(&args.bwBurst, "bw-limit-burst", 0, "allowed burst size for bandwidth limit, how many \"tokens\" can fit into leaky bucket")
425442
flag.UintVar(&args.bwBuckets, "bw-limit-buckets", 1024*1024, "number of buckets of bandwidth limit")
@@ -684,9 +701,7 @@ func run() int {
684701
}
685702

686703
if args.cert != "" {
687-
cfg, err1 := makeServerTLSConfig(args.cert, args.key, args.cafile,
688-
args.ciphers, args.curves,
689-
uint16(args.minTLSVersion), uint16(args.maxTLSVersion), !args.disableHTTP2)
704+
cfg, err1 := makeServerTLSConfig(args)
690705
if err1 != nil {
691706
mainLogger.Critical("TLS config construction failed: %v", err1)
692707
return 3
@@ -735,22 +750,24 @@ func run() int {
735750
Client: &acme.Client{DirectoryURL: args.autocertACME},
736751
Email: args.autocertEmail,
737752
}
738-
if args.autocertWhitelist.Value() != nil {
739-
m.HostPolicy = autocert.HostWhitelist(args.autocertWhitelist.Value()...)
753+
if args.autocertWhitelist.values != nil {
754+
m.HostPolicy = autocert.HostWhitelist(args.autocertWhitelist.values...)
740755
}
741756
if args.autocertHTTP != "" {
742757
go func() {
743-
log.Fatalf("HTTP-01 ACME challenge server stopped: %v",
758+
mainLogger.Critical("HTTP-01 ACME challenge server stopped: %v",
744759
http.ListenAndServe(args.autocertHTTP, m.HTTPHandler(nil)))
745760
}()
746761
}
747-
cfg := m.TLSConfig()
748-
cfg, err = updateServerTLSConfig(cfg, args.cafile, args.ciphers, args.curves,
749-
uint16(args.minTLSVersion), uint16(args.maxTLSVersion), !args.disableHTTP2)
762+
cfg, err := makeServerTLSConfig(args)
750763
if err != nil {
751764
mainLogger.Critical("TLS config construction failed: %v", err)
752765
return 3
753766
}
767+
cfg.GetCertificate = m.GetCertificate
768+
if len(cfg.NextProtos) > 0 {
769+
cfg.NextProtos = append(cfg.NextProtos, acme.ALPNProto)
770+
}
754771
listener = tls.NewListener(listener, cfg)
755772
}
756773
defer listener.Close()
@@ -861,66 +878,43 @@ func run() int {
861878
return 2
862879
}
863880

864-
func makeServerTLSConfig(certfile, keyfile, cafile, ciphers, curves string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) {
881+
func makeServerTLSConfig(args *CLIArgs) (*tls.Config, error) {
865882
cfg := tls.Config{
866-
MinVersion: minVer,
867-
MaxVersion: maxVer,
868-
}
869-
cert, err := tls.LoadX509KeyPair(certfile, keyfile)
870-
if err != nil {
871-
return nil, err
883+
MinVersion: uint16(args.minTLSVersion),
884+
MaxVersion: uint16(args.maxTLSVersion),
872885
}
873-
cfg.Certificates = []tls.Certificate{cert}
874-
if cafile != "" {
875-
roots, err := tlsutil.LoadCAfile(cafile)
886+
if args.cert != "" {
887+
cert, err := tls.LoadX509KeyPair(args.cert, args.key)
876888
if err != nil {
877889
return nil, err
878890
}
879-
cfg.ClientCAs = roots
880-
cfg.ClientAuth = tls.VerifyClientCertIfGiven
881-
}
882-
cfg.CipherSuites, err = tlsutil.ParseCipherList(ciphers)
883-
if err != nil {
884-
return nil, err
885-
}
886-
cfg.CurvePreferences, err = tlsutil.ParseCurveList(curves)
887-
if err != nil {
888-
return nil, err
889-
}
890-
if h2 {
891-
cfg.NextProtos = []string{"h2", "http/1.1"}
892-
} else {
893-
cfg.NextProtos = []string{"http/1.1"}
891+
cfg.Certificates = []tls.Certificate{cert}
894892
}
895-
return &cfg, nil
896-
}
897-
898-
func updateServerTLSConfig(cfg *tls.Config, cafile, ciphers, curves string, minVer, maxVer uint16, h2 bool) (*tls.Config, error) {
899-
if cafile != "" {
900-
roots, err := tlsutil.LoadCAfile(cafile)
893+
if args.cafile != "" {
894+
roots, err := tlsutil.LoadCAfile(args.cafile)
901895
if err != nil {
902896
return nil, err
903897
}
904898
cfg.ClientCAs = roots
905899
cfg.ClientAuth = tls.VerifyClientCertIfGiven
906900
}
907901
var err error
908-
cfg.CipherSuites, err = tlsutil.ParseCipherList(ciphers)
902+
cfg.CipherSuites, err = tlsutil.ParseCipherList(args.ciphers)
909903
if err != nil {
910904
return nil, err
911905
}
912-
cfg.CurvePreferences, err = tlsutil.ParseCurveList(curves)
906+
cfg.CurvePreferences, err = tlsutil.ParseCurveList(args.curves)
913907
if err != nil {
914908
return nil, err
915909
}
916-
if h2 {
917-
cfg.NextProtos = []string{"h2", "http/1.1", "acme-tls/1"}
918-
} else {
919-
cfg.NextProtos = []string{"http/1.1", "acme-tls/1"}
910+
if args.tlsALPNEnabled {
911+
if !args.disableHTTP2 {
912+
cfg.NextProtos = []string{"h2", "http/1.1"}
913+
} else {
914+
cfg.NextProtos = []string{"http/1.1"}
915+
}
920916
}
921-
cfg.MinVersion = minVer
922-
cfg.MaxVersion = maxVer
923-
return cfg, nil
917+
return &cfg, nil
924918
}
925919

926920
func list_ciphers() {

0 commit comments

Comments
 (0)