Skip to content

Commit c24d1d2

Browse files
authored
JWKSCache: add option to set CA certificate to trust (dapr#81)
This is helpful when the JWKS is located on a HTTPS endpoint and the certificate is signed by a custom CA. Signed-off-by: ItalyPaleAle <[email protected]>
1 parent 77f7f03 commit c24d1d2

File tree

2 files changed

+69
-4
lines changed

2 files changed

+69
-4
lines changed

jwkscache/cache.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package jwkscache
2222
import (
2323
"context"
2424
"crypto/tls"
25+
"crypto/x509"
2526
"encoding/base64"
2627
"errors"
2728
"fmt"
@@ -38,6 +39,7 @@ import (
3839

3940
"github.com/dapr/kit/fswatcher"
4041
"github.com/dapr/kit/logger"
42+
"github.com/dapr/kit/utils"
4143
)
4244

4345
const (
@@ -49,11 +51,11 @@ const (
4951

5052
// JWKSCache is a cache of JWKS objects.
5153
// It fetches a JWKS object from a file on disk, a URL, or from a value passed as-is.
52-
// TODO: Move this to dapr/kit and use it for the JWKS crypto component too
5354
type JWKSCache struct {
5455
location string
5556
requestTimeout time.Duration
5657
minRefreshInterval time.Duration
58+
caCertificate string
5759

5860
jwks jwk.Set
5961
logger logger.Logger
@@ -113,6 +115,12 @@ func (c *JWKSCache) SetMinRefreshInterval(minRefreshInterval time.Duration) {
113115
c.minRefreshInterval = minRefreshInterval
114116
}
115117

118+
// SetCACertificate sets the CA certificate to trust.
119+
// Can be a path to a local file or an actual, PEM-encoded certificate
120+
func (c *JWKSCache) SetCACertificate(caCertificate string) {
121+
c.caCertificate = caCertificate
122+
}
123+
116124
// SetHTTPClient sets the HTTP client object to use.
117125
func (c *JWKSCache) SetHTTPClient(client *http.Client) {
118126
c.client = client
@@ -184,12 +192,28 @@ func (c *JWKSCache) initJWKSFromURL(ctx context.Context, url string) error {
184192

185193
// We also need to create a custom HTTP client (if we don't have one already) because otherwise there's no timeout.
186194
if c.client == nil {
195+
tlsConfig := &tls.Config{
196+
MinVersion: tls.VersionTLS12,
197+
}
198+
199+
// Load CA certificates if we have one
200+
if c.caCertificate != "" {
201+
caCert, err := utils.GetPEM(c.caCertificate)
202+
if err != nil {
203+
return fmt.Errorf("failed to load CA certificate: %w", err)
204+
}
205+
206+
caCertPool := x509.NewCertPool()
207+
if !caCertPool.AppendCertsFromPEM(caCert) {
208+
return errors.New("failed to add root certificate to certificate pool")
209+
}
210+
tlsConfig.RootCAs = caCertPool
211+
}
212+
187213
c.client = &http.Client{
188214
Timeout: c.requestTimeout,
189215
Transport: &http.Transport{
190-
TLSClientConfig: &tls.Config{
191-
MinVersion: tls.VersionTLS12,
192-
},
216+
TLSClientConfig: tlsConfig,
193217
},
194218
}
195219
}

utils/pem.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
Copyright 2023 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package utils
15+
16+
import (
17+
"encoding/pem"
18+
"fmt"
19+
"os"
20+
)
21+
22+
// GetPEM loads a PEM-encoded file (certificate or key).
23+
func GetPEM(val string) ([]byte, error) {
24+
// If val is already a PEM-encoded string, return it as-is
25+
if IsValidPEM(val) {
26+
return []byte(val), nil
27+
}
28+
29+
// Assume it's a file
30+
pemBytes, err := os.ReadFile(val)
31+
if err != nil {
32+
return nil, fmt.Errorf("value is neither a valid file path or nor a valid PEM-encoded string: %w", err)
33+
}
34+
return pemBytes, nil
35+
}
36+
37+
// IsValidPEM validates the provided input has PEM formatted block.
38+
func IsValidPEM(val string) bool {
39+
block, _ := pem.Decode([]byte(val))
40+
return block != nil
41+
}

0 commit comments

Comments
 (0)