Skip to content

Commit 97cf0a9

Browse files
committed
rfq: allow os, custom certificates
Adds both 'TrustSystemRootCAs' and 'CustomCertificates' to the rfq TLSConfig. The former indicates whether or not to trust the operating system's root CA list; the latter allows additional certificates (CA or self-signed) to be trusted. Also adds a basic unit test skeleton.
1 parent 166bc8c commit 97cf0a9

File tree

3 files changed

+140
-12
lines changed

3 files changed

+140
-12
lines changed

rfq/oracle.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -222,27 +222,21 @@ func NewRpcPriceOracle(addrStr string, tlsConfig *TLSConfig) (*RpcPriceOracle,
222222
return nil, err
223223
}
224224

225-
// Connect to the RPC server.
226-
dialOpts, err := serverDialOpts()
225+
// Create transport credentials and dial options from the supplied TLS
226+
// config.
227+
transportCredentials, err := configureTransportCredentials(tlsConfig)
227228
if err != nil {
228229
return nil, err
229230
}
230231

231-
// Determine whether we should skip certificate verification.
232-
dialInsecure := tlsConfig.InsecureSkipVerify
233-
234-
// Allow connecting to a non-TLS (h2c, http over cleartext) gRPC server,
235-
// should be used for testing only.
236-
if dialInsecure {
237-
dialOpts, err = insecureServerDialOpts()
238-
if err != nil {
239-
return nil, err
240-
}
232+
dialOpts := []grpc.DialOption{
233+
grpc.WithTransportCredentials(transportCredentials),
241234
}
242235

243236
// Formulate the server address dial string.
244237
serverAddr := fmt.Sprintf("%s:%s", addr.Hostname(), addr.Port())
245238

239+
// Connect to the RPC server.
246240
conn, err := grpc.Dial(serverAddr, dialOpts...)
247241
if err != nil {
248242
return nil, err

rfq/tls.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
package rfq
22

3+
import (
4+
"crypto/x509"
5+
6+
"google.golang.org/grpc/credentials"
7+
"google.golang.org/grpc/credentials/insecure"
8+
)
9+
310
// TLSConfig represents TLS configuration options for oracle connections.
411
type TLSConfig struct {
512
// InsecureSkipVerify disables certificate verification.
613
InsecureSkipVerify bool
14+
15+
// TrustSystemRootCAs indicates whether or not to use the operating
16+
// system's root certificate authority list.
17+
TrustSystemRootCAs bool
18+
19+
// CustomCertificates contains PEM data for additional root CA and
20+
// self-signed certificates to trust.
21+
CustomCertificates []byte
722
}
823

924
// DefaultTLSConfig returns a default TLS configuration.
@@ -12,3 +27,37 @@ func DefaultTLSConfig() *TLSConfig {
1227
InsecureSkipVerify: true,
1328
}
1429
}
30+
31+
// configureTransportCredentials configures the TLS transport credentials to
32+
// be used for RPC connections.
33+
func configureTransportCredentials(
34+
config *TLSConfig) (credentials.TransportCredentials, error) {
35+
36+
// If we're to skip certificate verification, then just return
37+
// insecure credentials here.
38+
if config.InsecureSkipVerify {
39+
return insecure.NewCredentials(), nil
40+
}
41+
42+
// Initialize the certificate pool.
43+
certPool, err := constructCertPool(config.TrustSystemRootCAs)
44+
if err != nil {
45+
return nil, err
46+
}
47+
48+
// If we have any custom certificates, add them to the certificate
49+
// pool.
50+
certPool.AppendCertsFromPEM(config.CustomCertificates)
51+
52+
// Return the constructed transport credentials.
53+
return credentials.NewClientTLSFromCert(certPool, ""), nil
54+
}
55+
56+
// constructCertPool is a helper for constructing an initial certificate pool,
57+
// depending on whether or not we should trust the system root CA list.
58+
func constructCertPool(trustSystem bool) (*x509.CertPool, error) {
59+
if trustSystem {
60+
return x509.SystemCertPool()
61+
}
62+
return x509.NewCertPool(), nil
63+
}

rfq/tls_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package rfq
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
// Test certificate data - a valid self-signed certificate for testing
10+
const validTestCertPEM = `-----BEGIN CERTIFICATE-----
11+
MIICmjCCAYICCQCuu1gzY+BBKjANBgkqhkiG9w0BAQsFADAPMQ0wCwYDVQQDDAR0
12+
ZXN0MB4XDTI1MDgyODEwNDA1NVoXDTI1MDgyOTEwNDA1NVowDzENMAsGA1UEAwwE
13+
dGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALTWCm8l3d9nE2QK
14+
TK8HJ36ftO8pK3//nb8Nj/p97FrPFSgzdgL1ZNJs4gP5/ZsU+iE6VeKhalHoSf6/
15+
IMLe3ATTL0rWA1M6z7cw6ll8VS8NQMaMSFWNomncsxyoJAQde++SC5f1RwQJBD/0
16+
gGB4bJIIqUHtT12m23GLX48d6JGEEi5kEQtk91S/QGnHtglzZ8CQOogDBzDhSHu2
17+
jj4mKYDgkXcyAqN7DoDzoEcrpeAaeAwem8k1sFBeTtrqT1ot7Ey5KG+RUyJbdKGt
18+
5adJiwH782NgsSnISQ2X7Sct6Uu0JzHKx9JzyABsA05tf3cNJkLhh1Is9edYI2e9
19+
m0dqedECAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAQOCs/7xZVPjabbhdv30mUJMG
20+
lddi2A+R/5IRXW1MKnpemwiv4ZWYQ9PMTmuR7kqaF7AGLkvx+5sp2evUJN4x7vHP
21+
ao6wihbdh+vBkrobE+Y9dE7nbkvMQSNi1sXzDnfZB9LqY9Huun2soUwBQNCMPVMa
22+
Wo7g6udwyA48doEVJMjThFLPcW7xmsy6Ldew682m1kD8/ag+9qihX1IJyiqiEjha
23+
3uT4CT+zEg0RJorEJKbR38fE4Uhx1wZO4zvjEg6qZeW/I4lw+UzSY5xV7lJ1EQvf
24+
BcoNuBHB65RxQM5fpA7hkEFm1bxBoowGX2hx6VCCeBBwREISRfgvkUxZahUXNg==
25+
-----END CERTIFICATE-----`
26+
27+
// Invalid PEM data for testing failure cases
28+
const invalidTestCertPEM = `-----BEGIN CERTIFICATE-----
29+
This is not a valid certificate
30+
-----END CERTIFICATE-----`
31+
32+
// DefaultTLSConfig returns a default TLS configuration for testing.
33+
func DefaultTLSConfig() *TLSConfig {
34+
return &TLSConfig{
35+
InsecureSkipVerify: true,
36+
}
37+
}
38+
39+
// TestConfigureTransportCredentials_InsecureSkipVerify tests the function
40+
// when InsecureSkipVerify is true.
41+
func TestConfigureTransportCredentials_InsecureSkipVerify(t *testing.T) {
42+
config := &TLSConfig{
43+
InsecureSkipVerify: true,
44+
}
45+
46+
creds, err := configureTransportCredentials(config)
47+
48+
require.NoError(t, err)
49+
require.NotNil(t, creds)
50+
51+
// Verify that we got insecure credentials by checking the type
52+
require.Equal(t, "insecure", creds.Info().SecurityProtocol)
53+
}
54+
55+
// TestConfigureTransportCredentials_ValidCustomCertificates tests the
56+
// function when valid custom certificates are provided.
57+
func TestConfigureTransportCredentials_ValidCustomCertificates(t *testing.T) {
58+
config := &TLSConfig{
59+
InsecureSkipVerify: false,
60+
CustomCertificates: []byte(validTestCertPEM),
61+
}
62+
63+
creds, err := configureTransportCredentials(config)
64+
65+
require.NoError(t, err)
66+
require.NotNil(t, creds)
67+
68+
// Verify that we got TLS credentials (not insecure)
69+
require.Equal(t, "tls", creds.Info().SecurityProtocol)
70+
}
71+
72+
// TestConfigureTransportCredentials_NoCredentialsConfigured tests the
73+
// function when no credentials are configured.
74+
func TestConfigureTransportCredentials_NoCredentialsConfigured(t *testing.T) {
75+
config := &TLSConfig{
76+
InsecureSkipVerify: false,
77+
CustomCertificates: nil,
78+
}
79+
80+
creds, err := configureTransportCredentials(config)
81+
82+
require.NoError(t, err)
83+
require.NotNil(t, creds)
84+
require.Equal(t, "tls", creds.Info().SecurityProtocol)
85+
}

0 commit comments

Comments
 (0)