Skip to content

Commit 0c7f5ea

Browse files
committed
GoogleID libs
1 parent df1d968 commit 0c7f5ea

File tree

9 files changed

+789
-690
lines changed

9 files changed

+789
-690
lines changed

cmd/plugin-googleid-info/main.go

Lines changed: 8 additions & 239 deletions
Original file line numberDiff line numberDiff line change
@@ -1,174 +1,15 @@
11
package main
22

33
import (
4-
"context"
5-
"crypto/rsa"
6-
"errors"
74
"flag"
85
"fmt"
9-
"github.com/cenkalti/backoff"
10-
"github.com/grepplabs/kafka-proxy/pkg/apis"
11-
"github.com/grepplabs/kafka-proxy/pkg/libs/googleid"
6+
"github.com/grepplabs/kafka-proxy/pkg/libs/googleid-info"
127
"github.com/grepplabs/kafka-proxy/plugin/gateway-server/shared"
138
"github.com/hashicorp/go-plugin"
149
"github.com/sirupsen/logrus"
15-
"golang.org/x/oauth2/jws"
1610
"os"
17-
"regexp"
18-
"sort"
19-
"sync"
20-
"time"
2111
)
2212

23-
const (
24-
StatusOK = 0
25-
StatusEmptyToken = 1
26-
StatusParseJWTFailed = 2
27-
StatusNoIssueTimeInToken = 3
28-
StatusNoExpirationTimeInToken = 4
29-
StatusPublicKeyNotFound = 5
30-
StatusWrongIssuer = 6
31-
StatusWrongSignature = 7
32-
StatusTokenTooEarly = 8
33-
StatusTokenExpired = 9
34-
StatusWrongAudience = 10
35-
StatusWrongEmail = 11
36-
)
37-
38-
var (
39-
clockSkew = 1 * time.Minute
40-
nowFn = time.Now
41-
42-
issuers = map[string]struct{}{
43-
"accounts.google.com": {},
44-
"https://accounts.google.com": {},
45-
}
46-
)
47-
48-
type TokenInfo struct {
49-
timeout time.Duration
50-
audience map[string]struct{}
51-
emailRegex []*regexp.Regexp
52-
53-
publicKeys map[string]*rsa.PublicKey
54-
l sync.RWMutex
55-
}
56-
57-
func (p *TokenInfo) getPublicKey(kid string) *rsa.PublicKey {
58-
p.l.RLock()
59-
defer p.l.RUnlock()
60-
61-
return p.publicKeys[kid]
62-
}
63-
64-
func (p *TokenInfo) getPublicKeyIDs() []string {
65-
p.l.RLock()
66-
defer p.l.RUnlock()
67-
kids := make([]string, 0)
68-
for kid := range p.publicKeys {
69-
kids = append(kids, kid)
70-
}
71-
return kids
72-
}
73-
74-
func (p *TokenInfo) setPublicKeys(publicKeys map[string]*rsa.PublicKey) {
75-
p.l.Lock()
76-
defer p.l.Unlock()
77-
78-
p.publicKeys = publicKeys
79-
}
80-
81-
func (p *TokenInfo) refreshCerts() error {
82-
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
83-
defer cancel()
84-
85-
certs, err := googleid.GetCerts(ctx)
86-
if err != nil {
87-
return err
88-
}
89-
if len(certs.Keys) == 0 {
90-
return errors.New("certs keys must not be empty")
91-
}
92-
93-
publicKeys := make(map[string]*rsa.PublicKey)
94-
for _, key := range certs.Keys {
95-
publicKey, err := key.GetPublicKey()
96-
if err != nil {
97-
return fmt.Errorf("cannot parse public key: %v", key.Kid)
98-
}
99-
publicKeys[key.Kid] = publicKey
100-
}
101-
102-
p.setPublicKeys(publicKeys)
103-
104-
return nil
105-
}
106-
107-
func (p *TokenInfo) VerifyToken(parent context.Context, request apis.VerifyRequest) (apis.VerifyResponse, error) {
108-
// logrus.Infof("VerifyToken: %s", request.Token)
109-
if request.Token == "" {
110-
return getVerifyResponseResponse(StatusEmptyToken)
111-
}
112-
113-
token, err := googleid.ParseJWT(request.Token)
114-
if err != nil {
115-
return getVerifyResponseResponse(StatusParseJWTFailed)
116-
}
117-
if _, ok := issuers[token.ClaimSet.Iss]; !ok {
118-
return getVerifyResponseResponse(StatusWrongIssuer)
119-
}
120-
if token.ClaimSet.Iat < 1 {
121-
return getVerifyResponseResponse(StatusNoIssueTimeInToken)
122-
}
123-
if token.ClaimSet.Exp < 1 {
124-
return getVerifyResponseResponse(StatusNoExpirationTimeInToken)
125-
}
126-
127-
earliest := token.ClaimSet.Iat - int64(clockSkew.Seconds())
128-
latest := token.ClaimSet.Exp + int64(clockSkew.Seconds())
129-
unix := nowFn().Unix()
130-
131-
if unix < earliest {
132-
return getVerifyResponseResponse(StatusTokenTooEarly)
133-
}
134-
if unix > latest {
135-
return getVerifyResponseResponse(StatusTokenExpired)
136-
}
137-
138-
if len(p.audience) != 0 {
139-
if _, ok := p.audience[token.ClaimSet.Aud]; !ok {
140-
return getVerifyResponseResponse(StatusWrongAudience)
141-
}
142-
}
143-
if !p.checkEmail(token.ClaimSet.Email) {
144-
return getVerifyResponseResponse(StatusWrongEmail)
145-
}
146-
147-
publicKey := p.getPublicKey(token.Header.KeyID)
148-
if publicKey == nil {
149-
return getVerifyResponseResponse(StatusPublicKeyNotFound)
150-
}
151-
err = jws.Verify(request.Token, publicKey)
152-
if err != nil {
153-
return getVerifyResponseResponse(StatusWrongSignature)
154-
}
155-
return apis.VerifyResponse{Success: true}, nil
156-
}
157-
158-
func (p *TokenInfo) checkEmail(email string) bool {
159-
for _, re := range p.emailRegex {
160-
if re.MatchString(email) {
161-
return true
162-
}
163-
}
164-
return false
165-
}
166-
167-
func getVerifyResponseResponse(status int) (apis.VerifyResponse, error) {
168-
success := status == StatusOK
169-
return apis.VerifyResponse{Success: success, Status: int32(status)}, nil
170-
}
171-
17213
func (f *pluginMeta) flagSet() *flag.FlagSet {
17314
fs := flag.NewFlagSet("google-id info settings", flag.ContinueOnError)
17415
return fs
@@ -205,52 +46,24 @@ func main() {
20546
fs := pluginMeta.flagSet()
20647
fs.IntVar(&pluginMeta.timeout, "timeout", 10, "Request timeout in seconds")
20748
fs.IntVar(&pluginMeta.certsRefreshInterval, "certs-refresh-interval", 60*60, "Certificates refresh interval in seconds")
208-
20949
fs.Var(&pluginMeta.audience, "audience", "The audience of a token")
21050
fs.Var(&pluginMeta.emailsRegex, "email-regex", "Regex of the email claim")
21151

21252
fs.Parse(os.Args[1:])
21353

214-
logrus.Infof("Plugin metadata %v", pluginMeta)
215-
216-
audience := pluginMeta.audience.asMap()
217-
218-
emailRegex := make([]*regexp.Regexp, 0)
219-
for _, emailRe := range pluginMeta.emailsRegex {
220-
re, err := regexp.Compile(emailRe)
221-
if err != nil {
222-
logrus.Errorf("cannot compile email regex %s: %v", emailRe, err)
223-
os.Exit(1)
224-
}
225-
emailRegex = append(emailRegex, re)
226-
}
227-
228-
logrus.Infof("JWT target audience: %v", audience)
229-
logrus.Infof("JWT emails regexp: %v", emailRegex)
230-
231-
if len(emailRegex) == 0 {
232-
logrus.Errorf("parameter email (regex) is required")
233-
os.Exit(1)
54+
opts := googleidinfo.TokenInfoOptions{
55+
Timeout: pluginMeta.timeout,
56+
CertsRefreshInterval: pluginMeta.certsRefreshInterval,
57+
Audience: pluginMeta.audience,
58+
EmailsRegex: pluginMeta.emailsRegex,
23459
}
23560

236-
tokenInfo := &TokenInfo{timeout: time.Duration(pluginMeta.timeout) * time.Second, audience: audience, emailRegex: emailRegex}
237-
238-
op := func() error {
239-
return tokenInfo.refreshCerts()
240-
}
241-
err := backoff.Retry(op, backoff.WithMaxTries(backoff.NewConstantBackOff(1*time.Second), 3))
61+
tokenInfo, err := googleidinfo.NewTokenInfo(opts)
24262
if err != nil {
243-
logrus.Errorf("getting of google certs failed : %v", err)
63+
logrus.Errorf("cannot initialize googleid-info provider: %v", err)
24464
os.Exit(1)
24565
}
24666

247-
certsRefresher := &CertsRefresher{
248-
tokenInfo: tokenInfo,
249-
stopChannel: make(chan struct{}, 1),
250-
}
251-
252-
go certsRefresher.refreshLoop(time.Duration(pluginMeta.certsRefreshInterval) * time.Second)
253-
25467
plugin.Serve(&plugin.ServeConfig{
25568
HandshakeConfig: shared.Handshake,
25669
Plugins: map[string]plugin.Plugin{
@@ -260,47 +73,3 @@ func main() {
26073
GRPCServer: plugin.DefaultGRPCServer,
26174
})
26275
}
263-
264-
type CertsRefresher struct {
265-
tokenInfo *TokenInfo
266-
stopChannel chan struct{}
267-
}
268-
269-
func (p *CertsRefresher) refreshLoop(interval time.Duration) {
270-
defer func() {
271-
if r := recover(); r != nil {
272-
var ok bool
273-
err, ok := r.(error)
274-
if ok {
275-
logrus.Error(fmt.Sprintf("certs refresh loop error %v", err))
276-
}
277-
}
278-
}()
279-
logrus.Infof("Refreshing certs every: %v", interval)
280-
syncTicker := time.NewTicker(interval)
281-
for {
282-
select {
283-
case <-syncTicker.C:
284-
p.refreshTick()
285-
case <-p.stopChannel:
286-
return
287-
}
288-
}
289-
}
290-
291-
func (p *CertsRefresher) refreshTick() error {
292-
op := func() error {
293-
return p.tokenInfo.refreshCerts()
294-
}
295-
backOff := backoff.NewExponentialBackOff()
296-
backOff.MaxElapsedTime = 30 * time.Minute
297-
backOff.MaxInterval = 2 * time.Minute
298-
err := backoff.Retry(op, backOff)
299-
if err != nil {
300-
return err
301-
}
302-
kids := p.tokenInfo.getPublicKeyIDs()
303-
sort.Strings(kids)
304-
logrus.Infof("Refreshed certs Key IDs: %v", kids)
305-
return nil
306-
}

0 commit comments

Comments
 (0)