11package main
22
33import (
4- "context"
5- "crypto/rsa"
6- "errors"
74 "flag"
85 "fmt"
9- "github.com/cenkalti/backoff"
10- "github.com/grepplabs/kafka-proxy/pkg/apis"
11- "github.com/grepplabs/kafka-proxy/pkg/libs/googleid"
6+ "github.com/grepplabs/kafka-proxy/pkg/libs/googleid-info"
127 "github.com/grepplabs/kafka-proxy/plugin/gateway-server/shared"
138 "github.com/hashicorp/go-plugin"
149 "github.com/sirupsen/logrus"
15- "golang.org/x/oauth2/jws"
1610 "os"
17- "regexp"
18- "sort"
19- "sync"
20- "time"
2111)
2212
23- const (
24- StatusOK = 0
25- StatusEmptyToken = 1
26- StatusParseJWTFailed = 2
27- StatusNoIssueTimeInToken = 3
28- StatusNoExpirationTimeInToken = 4
29- StatusPublicKeyNotFound = 5
30- StatusWrongIssuer = 6
31- StatusWrongSignature = 7
32- StatusTokenTooEarly = 8
33- StatusTokenExpired = 9
34- StatusWrongAudience = 10
35- StatusWrongEmail = 11
36- )
37-
38- var (
39- clockSkew = 1 * time .Minute
40- nowFn = time .Now
41-
42- issuers = map [string ]struct {}{
43- "accounts.google.com" : {},
44- "https://accounts.google.com" : {},
45- }
46- )
47-
48- type TokenInfo struct {
49- timeout time.Duration
50- audience map [string ]struct {}
51- emailRegex []* regexp.Regexp
52-
53- publicKeys map [string ]* rsa.PublicKey
54- l sync.RWMutex
55- }
56-
57- func (p * TokenInfo ) getPublicKey (kid string ) * rsa.PublicKey {
58- p .l .RLock ()
59- defer p .l .RUnlock ()
60-
61- return p .publicKeys [kid ]
62- }
63-
64- func (p * TokenInfo ) getPublicKeyIDs () []string {
65- p .l .RLock ()
66- defer p .l .RUnlock ()
67- kids := make ([]string , 0 )
68- for kid := range p .publicKeys {
69- kids = append (kids , kid )
70- }
71- return kids
72- }
73-
74- func (p * TokenInfo ) setPublicKeys (publicKeys map [string ]* rsa.PublicKey ) {
75- p .l .Lock ()
76- defer p .l .Unlock ()
77-
78- p .publicKeys = publicKeys
79- }
80-
81- func (p * TokenInfo ) refreshCerts () error {
82- ctx , cancel := context .WithTimeout (context .Background (), p .timeout )
83- defer cancel ()
84-
85- certs , err := googleid .GetCerts (ctx )
86- if err != nil {
87- return err
88- }
89- if len (certs .Keys ) == 0 {
90- return errors .New ("certs keys must not be empty" )
91- }
92-
93- publicKeys := make (map [string ]* rsa.PublicKey )
94- for _ , key := range certs .Keys {
95- publicKey , err := key .GetPublicKey ()
96- if err != nil {
97- return fmt .Errorf ("cannot parse public key: %v" , key .Kid )
98- }
99- publicKeys [key .Kid ] = publicKey
100- }
101-
102- p .setPublicKeys (publicKeys )
103-
104- return nil
105- }
106-
107- func (p * TokenInfo ) VerifyToken (parent context.Context , request apis.VerifyRequest ) (apis.VerifyResponse , error ) {
108- // logrus.Infof("VerifyToken: %s", request.Token)
109- if request .Token == "" {
110- return getVerifyResponseResponse (StatusEmptyToken )
111- }
112-
113- token , err := googleid .ParseJWT (request .Token )
114- if err != nil {
115- return getVerifyResponseResponse (StatusParseJWTFailed )
116- }
117- if _ , ok := issuers [token .ClaimSet .Iss ]; ! ok {
118- return getVerifyResponseResponse (StatusWrongIssuer )
119- }
120- if token .ClaimSet .Iat < 1 {
121- return getVerifyResponseResponse (StatusNoIssueTimeInToken )
122- }
123- if token .ClaimSet .Exp < 1 {
124- return getVerifyResponseResponse (StatusNoExpirationTimeInToken )
125- }
126-
127- earliest := token .ClaimSet .Iat - int64 (clockSkew .Seconds ())
128- latest := token .ClaimSet .Exp + int64 (clockSkew .Seconds ())
129- unix := nowFn ().Unix ()
130-
131- if unix < earliest {
132- return getVerifyResponseResponse (StatusTokenTooEarly )
133- }
134- if unix > latest {
135- return getVerifyResponseResponse (StatusTokenExpired )
136- }
137-
138- if len (p .audience ) != 0 {
139- if _ , ok := p .audience [token .ClaimSet .Aud ]; ! ok {
140- return getVerifyResponseResponse (StatusWrongAudience )
141- }
142- }
143- if ! p .checkEmail (token .ClaimSet .Email ) {
144- return getVerifyResponseResponse (StatusWrongEmail )
145- }
146-
147- publicKey := p .getPublicKey (token .Header .KeyID )
148- if publicKey == nil {
149- return getVerifyResponseResponse (StatusPublicKeyNotFound )
150- }
151- err = jws .Verify (request .Token , publicKey )
152- if err != nil {
153- return getVerifyResponseResponse (StatusWrongSignature )
154- }
155- return apis.VerifyResponse {Success : true }, nil
156- }
157-
158- func (p * TokenInfo ) checkEmail (email string ) bool {
159- for _ , re := range p .emailRegex {
160- if re .MatchString (email ) {
161- return true
162- }
163- }
164- return false
165- }
166-
167- func getVerifyResponseResponse (status int ) (apis.VerifyResponse , error ) {
168- success := status == StatusOK
169- return apis.VerifyResponse {Success : success , Status : int32 (status )}, nil
170- }
171-
17213func (f * pluginMeta ) flagSet () * flag.FlagSet {
17314 fs := flag .NewFlagSet ("google-id info settings" , flag .ContinueOnError )
17415 return fs
@@ -205,52 +46,24 @@ func main() {
20546 fs := pluginMeta .flagSet ()
20647 fs .IntVar (& pluginMeta .timeout , "timeout" , 10 , "Request timeout in seconds" )
20748 fs .IntVar (& pluginMeta .certsRefreshInterval , "certs-refresh-interval" , 60 * 60 , "Certificates refresh interval in seconds" )
208-
20949 fs .Var (& pluginMeta .audience , "audience" , "The audience of a token" )
21050 fs .Var (& pluginMeta .emailsRegex , "email-regex" , "Regex of the email claim" )
21151
21252 fs .Parse (os .Args [1 :])
21353
214- logrus .Infof ("Plugin metadata %v" , pluginMeta )
215-
216- audience := pluginMeta .audience .asMap ()
217-
218- emailRegex := make ([]* regexp.Regexp , 0 )
219- for _ , emailRe := range pluginMeta .emailsRegex {
220- re , err := regexp .Compile (emailRe )
221- if err != nil {
222- logrus .Errorf ("cannot compile email regex %s: %v" , emailRe , err )
223- os .Exit (1 )
224- }
225- emailRegex = append (emailRegex , re )
226- }
227-
228- logrus .Infof ("JWT target audience: %v" , audience )
229- logrus .Infof ("JWT emails regexp: %v" , emailRegex )
230-
231- if len (emailRegex ) == 0 {
232- logrus .Errorf ("parameter email (regex) is required" )
233- os .Exit (1 )
54+ opts := googleidinfo.TokenInfoOptions {
55+ Timeout : pluginMeta .timeout ,
56+ CertsRefreshInterval : pluginMeta .certsRefreshInterval ,
57+ Audience : pluginMeta .audience ,
58+ EmailsRegex : pluginMeta .emailsRegex ,
23459 }
23560
236- tokenInfo := & TokenInfo {timeout : time .Duration (pluginMeta .timeout ) * time .Second , audience : audience , emailRegex : emailRegex }
237-
238- op := func () error {
239- return tokenInfo .refreshCerts ()
240- }
241- err := backoff .Retry (op , backoff .WithMaxTries (backoff .NewConstantBackOff (1 * time .Second ), 3 ))
61+ tokenInfo , err := googleidinfo .NewTokenInfo (opts )
24262 if err != nil {
243- logrus .Errorf ("getting of google certs failed : %v" , err )
63+ logrus .Errorf ("cannot initialize googleid-info provider : %v" , err )
24464 os .Exit (1 )
24565 }
24666
247- certsRefresher := & CertsRefresher {
248- tokenInfo : tokenInfo ,
249- stopChannel : make (chan struct {}, 1 ),
250- }
251-
252- go certsRefresher .refreshLoop (time .Duration (pluginMeta .certsRefreshInterval ) * time .Second )
253-
25467 plugin .Serve (& plugin.ServeConfig {
25568 HandshakeConfig : shared .Handshake ,
25669 Plugins : map [string ]plugin.Plugin {
@@ -260,47 +73,3 @@ func main() {
26073 GRPCServer : plugin .DefaultGRPCServer ,
26174 })
26275}
263-
264- type CertsRefresher struct {
265- tokenInfo * TokenInfo
266- stopChannel chan struct {}
267- }
268-
269- func (p * CertsRefresher ) refreshLoop (interval time.Duration ) {
270- defer func () {
271- if r := recover (); r != nil {
272- var ok bool
273- err , ok := r .(error )
274- if ok {
275- logrus .Error (fmt .Sprintf ("certs refresh loop error %v" , err ))
276- }
277- }
278- }()
279- logrus .Infof ("Refreshing certs every: %v" , interval )
280- syncTicker := time .NewTicker (interval )
281- for {
282- select {
283- case <- syncTicker .C :
284- p .refreshTick ()
285- case <- p .stopChannel :
286- return
287- }
288- }
289- }
290-
291- func (p * CertsRefresher ) refreshTick () error {
292- op := func () error {
293- return p .tokenInfo .refreshCerts ()
294- }
295- backOff := backoff .NewExponentialBackOff ()
296- backOff .MaxElapsedTime = 30 * time .Minute
297- backOff .MaxInterval = 2 * time .Minute
298- err := backoff .Retry (op , backOff )
299- if err != nil {
300- return err
301- }
302- kids := p .tokenInfo .getPublicKeyIDs ()
303- sort .Strings (kids )
304- logrus .Infof ("Refreshed certs Key IDs: %v" , kids )
305- return nil
306- }
0 commit comments