Skip to content

Commit 26027ab

Browse files
committed
automatically create TLS certs if they don't exist
1 parent 9a385d8 commit 26027ab

File tree

3 files changed

+40
-18
lines changed

3 files changed

+40
-18
lines changed

rserver.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"bufio"
1212
"github.com/hashicorp/yamux"
13+
"github.com/kost/tty2web/tlshelp"
1314
"strconv"
1415
"strings"
1516
"time"
@@ -29,7 +30,7 @@ func listenForAgents(verbose bool, tlslisten bool, address string, clients strin
2930
if tlslisten {
3031
log.Printf("Listening for agents on %s using TLS", address)
3132
if certificate == "" {
32-
cer, err = getRandomTLS(2048)
33+
cer, err = tlshelp.GetRandomTLS(2048)
3334
log.Println("No TLS certificate. Generated random one.")
3435
} else {
3536
cer, err = tls.LoadX509KeyPair(certificate+".crt", certificate+".key")

server/server.go

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/kost/tty2web/pkg/homedir"
2424
"github.com/kost/tty2web/pkg/randomstring"
2525
"github.com/kost/tty2web/webtty"
26+
"github.com/kost/tty2web/tlshelp"
2627
"github.com/kost/httpexecute"
2728
"github.com/kost/regeorgo"
2829
)
@@ -126,6 +127,23 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
126127

127128
srvErr := make(chan error, 1)
128129

130+
if server.options.EnableTLS {
131+
crtFile := homedir.Expand(server.options.TLSCrtFile)
132+
keyFile := homedir.Expand(server.options.TLSKeyFile)
133+
log.Printf("TLS crt file: " + crtFile)
134+
log.Printf("TLS key file: " + keyFile)
135+
cer, err := tls.LoadX509KeyPair(crtFile,keyFile)
136+
if err != nil {
137+
log.Printf("Error loading TLS key and crt file %s and %s: %v. Generating random one!", crtFile, keyFile, err)
138+
139+
cer, err = tlshelp.GetRandomTLS(2048)
140+
if err != nil {
141+
return errors.Wrapf(err, "error generating and failed to load tls cert and key `%s` and `%s`", crtFile, keyFile)
142+
}
143+
}
144+
config := &tls.Config{Certificates: []tls.Certificate{cer}}
145+
srv.TLSConfig=config
146+
}
129147
if server.options.Dns != "" {
130148
go func() {
131149
session, err = DnsConnectSocks(server.options.Dns, server.options.DnsKey, server.options.DnsDelay)
@@ -134,7 +152,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
134152
srvErr <- err
135153
return
136154
}
137-
err = srv.Serve(session)
155+
if server.options.EnableTLS {
156+
err = srv.ServeTLS(session, "", "")
157+
} else {
158+
err = srv.Serve(session)
159+
}
138160
if err != nil {
139161
srvErr <- err
140162
}
@@ -160,12 +182,7 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
160182
}
161183
go func() {
162184
if server.options.EnableTLS {
163-
crtFile := homedir.Expand(server.options.TLSCrtFile)
164-
keyFile := homedir.Expand(server.options.TLSKeyFile)
165-
log.Printf("TLS crt file: " + crtFile)
166-
log.Printf("TLS key file: " + keyFile)
167-
168-
err = srv.ServeTLS(listener, crtFile, keyFile)
185+
err = srv.ServeTLS(listener, "", "")
169186
} else {
170187
err = srv.Serve(listener)
171188
}
@@ -181,7 +198,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
181198
srvErr <- err
182199
return
183200
}
184-
err = srv.Serve(session)
201+
if server.options.EnableTLS {
202+
err = srv.ServeTLS(session, "", "")
203+
} else {
204+
err = srv.Serve(session)
205+
}
185206
if err != nil {
186207
srvErr <- err
187208
}

tlshelp.go renamed to tlshelp/tlshelp.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package main
1+
package tlshelp
22

33
import (
44
"crypto/rand"
@@ -49,7 +49,7 @@ func RandBigInt(max *big.Int) *big.Int {
4949
return r
5050
}
5151

52-
func genPair(keysize int) (cacert []byte, cakey []byte, cert []byte, certkey []byte) {
52+
func GenPair(keysize int) (cacert []byte, cakey []byte, cert []byte, certkey []byte) {
5353

5454
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
5555

@@ -105,7 +105,7 @@ func genPair(keysize int) (cacert []byte, cakey []byte, cert []byte, certkey []b
105105

106106
}
107107

108-
func verifyCert(cacert []byte, cert []byte) bool {
108+
func VerifyCert(cacert []byte, cert []byte) bool {
109109
caBin, _ := x509.ParseCertificate(cacert)
110110
cert2Bin, _ := x509.ParseCertificate(cert)
111111
err3 := cert2Bin.CheckSignatureFrom(caBin)
@@ -115,7 +115,7 @@ func verifyCert(cacert []byte, cert []byte) bool {
115115
return true
116116
}
117117

118-
func getPEMs(cert []byte, key []byte) (pemcert []byte, pemkey []byte) {
118+
func GetPEMs(cert []byte, key []byte) (pemcert []byte, pemkey []byte) {
119119
certPem := pem.EncodeToMemory(&pem.Block{
120120
Type: "CERTIFICATE",
121121
Bytes: cert,
@@ -129,17 +129,17 @@ func getPEMs(cert []byte, key []byte) (pemcert []byte, pemkey []byte) {
129129
return certPem, keyPem
130130
}
131131

132-
func getTLSPair(certPem []byte, keyPem []byte) (tls.Certificate, error) {
132+
func GetTLSPair(certPem []byte, keyPem []byte) (tls.Certificate, error) {
133133
tlspair, errt := tls.X509KeyPair(certPem, keyPem)
134134
if errt != nil {
135135
return tlspair, errt
136136
}
137137
return tlspair, nil
138138
}
139139

140-
func getRandomTLS(keysize int) (tls.Certificate, error) {
141-
_, _, cert, certkey := genPair(keysize)
142-
certPem, keyPem := getPEMs(cert, certkey)
143-
tlspair, err := getTLSPair(certPem, keyPem)
140+
func GetRandomTLS(keysize int) (tls.Certificate, error) {
141+
_, _, cert, certkey := GenPair(keysize)
142+
certPem, keyPem := GetPEMs(cert, certkey)
143+
tlspair, err := GetTLSPair(certPem, keyPem)
144144
return tlspair, err
145145
}

0 commit comments

Comments
 (0)