Skip to content

Commit 5915888

Browse files
committed
config: allow TLS data to be provided inline
This permits TLS data to be provided inline rather than by specifying a file on disk. Signed-off-by: Robert Fratto <[email protected]>
1 parent 2f04d2e commit 5915888

File tree

2 files changed

+149
-44
lines changed

2 files changed

+149
-44
lines changed

config/http_config.go

Lines changed: 147 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
579579
// No need for a RoundTripper that reloads the CA file automatically.
580580
return newRT(tlsConfig)
581581
}
582-
583-
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT)
582+
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT)
584583
}
585584

586585
type authorizationCredentialsRoundTripper struct {
@@ -750,7 +749,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
750749
if len(rt.config.TLSConfig.CAFile) == 0 {
751750
t, _ = tlsTransport(tlsConfig)
752751
} else {
753-
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport)
752+
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.roundTripperSettings(), tlsTransport)
754753
if err != nil {
755754
return nil, err
756755
}
@@ -817,6 +816,10 @@ func cloneRequest(r *http.Request) *http.Request {
817816

818817
// NewTLSConfig creates a new tls.Config from the given TLSConfig.
819818
func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
819+
if err := cfg.Validate(); err != nil {
820+
return nil, err
821+
}
822+
820823
tlsConfig := &tls.Config{
821824
InsecureSkipVerify: cfg.InsecureSkipVerify,
822825
MinVersion: uint16(cfg.MinVersion),
@@ -831,7 +834,11 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
831834

832835
// If a CA cert is provided then let's read it in so we can validate the
833836
// scrape target's certificate properly.
834-
if len(cfg.CAFile) > 0 {
837+
if len(cfg.CA) > 0 {
838+
if !updateRootCA(tlsConfig, []byte(cfg.CA)) {
839+
return nil, fmt.Errorf("unable to use inline CA cert")
840+
}
841+
} else if len(cfg.CAFile) > 0 {
835842
b, err := readCAFile(cfg.CAFile)
836843
if err != nil {
837844
return nil, err
@@ -844,12 +851,9 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
844851
if len(cfg.ServerName) > 0 {
845852
tlsConfig.ServerName = cfg.ServerName
846853
}
854+
847855
// If a client cert & key is provided then configure TLS config accordingly.
848-
if len(cfg.CertFile) > 0 && len(cfg.KeyFile) == 0 {
849-
return nil, fmt.Errorf("client cert file %q specified without client key file", cfg.CertFile)
850-
} else if len(cfg.KeyFile) > 0 && len(cfg.CertFile) == 0 {
851-
return nil, fmt.Errorf("client key file %q specified without client cert file", cfg.KeyFile)
852-
} else if len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 {
856+
if cfg.usingClientCert() && cfg.usingClientKey() {
853857
// Verify that client cert and key are valid.
854858
if _, err := cfg.getClientCertificate(nil); err != nil {
855859
return nil, err
@@ -862,6 +866,12 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
862866

863867
// TLSConfig configures the options for TLS connections.
864868
type TLSConfig struct {
869+
// Text of the CA cert to use for the targets.
870+
CA string `yaml:"ca,omitempty" json:"ca,omitempty"`
871+
// Text of the client cert file for the targets.
872+
Cert string `yaml:"cert,omitempty" json:"cert,omitempty"`
873+
// Text of the client key file for the targets.
874+
Key Secret `yaml:"key,omitempty" json:"key,omitempty"`
865875
// The CA cert to use for the targets.
866876
CAFile string `yaml:"ca_file,omitempty" json:"ca_file,omitempty"`
867877
// The client cert file for the targets.
@@ -891,7 +901,52 @@ func (c *TLSConfig) SetDirectory(dir string) {
891901
// UnmarshalYAML implements the yaml.Unmarshaler interface.
892902
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
893903
type plain TLSConfig
894-
return unmarshal((*plain)(c))
904+
if err := unmarshal((*plain)(c)); err != nil {
905+
return err
906+
}
907+
return c.Validate()
908+
}
909+
910+
// Validate validates the TLSConfig to check that only one of the inlined or
911+
// file-based fields for the TLS CA, client certificate, and client key are
912+
// used.
913+
func (c *TLSConfig) Validate() error {
914+
if len(c.CA) > 0 && len(c.CAFile) > 0 {
915+
return fmt.Errorf("at most one of ca and ca_file must be configured")
916+
}
917+
if len(c.Cert) > 0 && len(c.CertFile) > 0 {
918+
return fmt.Errorf("at most one of cert and cert_file must be configured")
919+
}
920+
if len(c.Key) > 0 && len(c.KeyFile) > 0 {
921+
return fmt.Errorf("at most one of key and key_file must be configured")
922+
}
923+
924+
if c.usingClientCert() && !c.usingClientKey() {
925+
return fmt.Errorf("exactly one of key or key_file must be configured when a client certificate is configured")
926+
} else if c.usingClientKey() && !c.usingClientCert() {
927+
return fmt.Errorf("exactly one of cert or cert_file must be configured when a client key is configured")
928+
}
929+
930+
return nil
931+
}
932+
933+
func (c *TLSConfig) usingClientCert() bool {
934+
return len(c.Cert) > 0 || len(c.CertFile) > 0
935+
}
936+
937+
func (c *TLSConfig) usingClientKey() bool {
938+
return len(c.Key) > 0 || len(c.KeyFile) > 0
939+
}
940+
941+
func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings {
942+
return TLSRoundTripperSettings{
943+
CA: c.CA,
944+
CAFile: c.CAFile,
945+
Cert: c.Cert,
946+
CertFile: c.CertFile,
947+
Key: string(c.Key),
948+
KeyFile: c.KeyFile,
949+
}
895950
}
896951

897952
// readCertAndKey reads the cert and key files from the disk.
@@ -911,9 +966,27 @@ func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) {
911966

912967
// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
913968
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
914-
certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile)
915-
if err != nil {
916-
return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
969+
var (
970+
certData, keyData []byte
971+
err error
972+
)
973+
974+
if c.CertFile != "" {
975+
certData, err = os.ReadFile(c.CertFile)
976+
if err != nil {
977+
return nil, fmt.Errorf("unable to read specified client cert (%s): %s", c.CertFile, err)
978+
}
979+
} else {
980+
certData = []byte(c.Cert)
981+
}
982+
983+
if c.KeyFile != "" {
984+
keyData, err = os.ReadFile(c.KeyFile)
985+
if err != nil {
986+
return nil, fmt.Errorf("unable to read specified client key (%s): %s", c.KeyFile, err)
987+
}
988+
} else {
989+
keyData = []byte(c.Key)
917990
}
918991

919992
cert, err := tls.X509KeyPair(certData, keyData)
@@ -946,30 +1019,32 @@ func updateRootCA(cfg *tls.Config, b []byte) bool {
9461019
// tlsRoundTripper is a RoundTripper that updates automatically its TLS
9471020
// configuration whenever the content of the CA file changes.
9481021
type tlsRoundTripper struct {
949-
caFile string
950-
certFile string
951-
keyFile string
1022+
settings TLSRoundTripperSettings
9521023

9531024
// newRT returns a new RoundTripper.
9541025
newRT func(*tls.Config) (http.RoundTripper, error)
9551026

9561027
mtx sync.RWMutex
9571028
rt http.RoundTripper
958-
hashCAFile []byte
959-
hashCertFile []byte
960-
hashKeyFile []byte
1029+
hashCAData []byte
1030+
hashCertData []byte
1031+
hashKeyData []byte
9611032
tlsConfig *tls.Config
9621033
}
9631034

1035+
type TLSRoundTripperSettings struct {
1036+
CA, CAFile string
1037+
Cert, CertFile string
1038+
Key, KeyFile string
1039+
}
1040+
9641041
func NewTLSRoundTripper(
9651042
cfg *tls.Config,
966-
caFile, certFile, keyFile string,
1043+
settings TLSRoundTripperSettings,
9671044
newRT func(*tls.Config) (http.RoundTripper, error),
9681045
) (http.RoundTripper, error) {
9691046
t := &tlsRoundTripper{
970-
caFile: caFile,
971-
certFile: certFile,
972-
keyFile: keyFile,
1047+
settings: settings,
9731048
newRT: newRT,
9741049
tlsConfig: cfg,
9751050
}
@@ -979,44 +1054,74 @@ func NewTLSRoundTripper(
9791054
return nil, err
9801055
}
9811056
t.rt = rt
982-
_, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash()
1057+
_, t.hashCAData, t.hashCertData, t.hashKeyData, err = t.getTLSDataWithHash()
9831058
if err != nil {
9841059
return nil, err
9851060
}
9861061

9871062
return t, nil
9881063
}
9891064

990-
func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) {
991-
b1, err := readCAFile(t.caFile)
992-
if err != nil {
993-
return nil, nil, nil, nil, err
1065+
func (t *tlsRoundTripper) getTLSDataWithHash() ([]byte, []byte, []byte, []byte, error) {
1066+
var (
1067+
caBytes, certBytes, keyBytes []byte
1068+
1069+
err error
1070+
)
1071+
1072+
if t.settings.CAFile != "" {
1073+
caBytes, err = os.ReadFile(t.settings.CAFile)
1074+
if err != nil {
1075+
return nil, nil, nil, nil, err
1076+
}
1077+
} else if t.settings.CA != "" {
1078+
caBytes = []byte(t.settings.CA)
1079+
}
1080+
1081+
if t.settings.CertFile != "" {
1082+
certBytes, err = os.ReadFile(t.settings.CertFile)
1083+
if err != nil {
1084+
return nil, nil, nil, nil, err
1085+
}
1086+
} else if t.settings.Cert != "" {
1087+
certBytes = []byte(t.settings.Cert)
9941088
}
995-
h1 := sha256.Sum256(b1)
9961089

997-
var h2, h3 [32]byte
998-
if t.certFile != "" {
999-
b2, b3, err := readCertAndKey(t.certFile, t.keyFile)
1090+
if t.settings.KeyFile != "" {
1091+
keyBytes, err = os.ReadFile(t.settings.KeyFile)
10001092
if err != nil {
10011093
return nil, nil, nil, nil, err
10021094
}
1003-
h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3)
1095+
} else if t.settings.Key != "" {
1096+
keyBytes = []byte(t.settings.Key)
1097+
}
1098+
1099+
var caHash, certHash, keyHash [32]byte
1100+
1101+
if len(caBytes) > 0 {
1102+
caHash = sha256.Sum256(caBytes)
1103+
}
1104+
if len(certBytes) > 0 {
1105+
certHash = sha256.Sum256(certBytes)
1106+
}
1107+
if len(keyBytes) > 0 {
1108+
keyHash = sha256.Sum256(keyBytes)
10041109
}
10051110

1006-
return b1, h1[:], h2[:], h3[:], nil
1111+
return caBytes, caHash[:], certHash[:], keyHash[:], nil
10071112
}
10081113

10091114
// RoundTrip implements the http.RoundTrip interface.
10101115
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
1011-
caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash()
1116+
caData, caHash, certHash, keyHash, err := t.getTLSDataWithHash()
10121117
if err != nil {
10131118
return nil, err
10141119
}
10151120

10161121
t.mtx.RLock()
1017-
equal := bytes.Equal(caHash[:], t.hashCAFile) &&
1018-
bytes.Equal(certHash[:], t.hashCertFile) &&
1019-
bytes.Equal(keyHash[:], t.hashKeyFile)
1122+
equal := bytes.Equal(caHash[:], t.hashCAData) &&
1123+
bytes.Equal(certHash[:], t.hashCertData) &&
1124+
bytes.Equal(keyHash[:], t.hashKeyData)
10201125
rt := t.rt
10211126
t.mtx.RUnlock()
10221127
if equal {
@@ -1029,7 +1134,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
10291134
// using GetClientCertificate.
10301135
tlsConfig := t.tlsConfig.Clone()
10311136
if !updateRootCA(tlsConfig, caData) {
1032-
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
1137+
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CAFile)
10331138
}
10341139
rt, err = t.newRT(tlsConfig)
10351140
if err != nil {
@@ -1039,9 +1144,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
10391144

10401145
t.mtx.Lock()
10411146
t.rt = rt
1042-
t.hashCAFile = caHash[:]
1043-
t.hashCertFile = certHash[:]
1044-
t.hashKeyFile = keyHash[:]
1147+
t.hashCAData = caHash[:]
1148+
t.hashCertData = certHash[:]
1149+
t.hashKeyData = keyHash[:]
10451150
t.mtx.Unlock()
10461151

10471152
return rt.RoundTrip(req)

config/http_config_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,15 +774,15 @@ func TestTLSConfigInvalidCA(t *testing.T) {
774774
KeyFile: ClientKeyNoPassPath,
775775
ServerName: "",
776776
InsecureSkipVerify: false},
777-
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
777+
errorMessage: fmt.Sprintf("unable to read specified client cert (%s):", MissingCert),
778778
}, {
779779
configTLSConfig: TLSConfig{
780780
CAFile: "",
781781
CertFile: ClientCertificatePath,
782782
KeyFile: MissingKey,
783783
ServerName: "",
784784
InsecureSkipVerify: false},
785-
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
785+
errorMessage: fmt.Sprintf("unable to read specified client key (%s):", MissingKey),
786786
},
787787
}
788788

0 commit comments

Comments
 (0)