Skip to content

Commit 8a5c976

Browse files
author
01393707
committed
Added multi-key support in jwt mode
1 parent fc4c9e2 commit 8a5c976

File tree

3 files changed

+107
-102
lines changed

3 files changed

+107
-102
lines changed

example/server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func main() {
2626
manager.MustTokenStorage(store.NewMemoryTokenStore())
2727

2828
// generate jwt access token
29-
manager.MapAccessGenerate(generates.NewJWTAccessGenerate([]byte("00000000"), jwt.SigningMethodHS512))
29+
manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512))
3030

3131
clientStore := store.NewClientStore()
3232
clientStore.Set("222222", &models.Client{

generates/jwt_access.go

Lines changed: 105 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,105 @@
1-
package generates
2-
3-
import (
4-
"context"
5-
"encoding/base64"
6-
"strings"
7-
"time"
8-
9-
errs "errors"
10-
11-
"github.com/dgrijalva/jwt-go"
12-
"gopkg.in/oauth2.v4"
13-
"gopkg.in/oauth2.v4/errors"
14-
"gopkg.in/oauth2.v4/utils/uuid"
15-
)
16-
17-
// JWTAccessClaims jwt claims
18-
type JWTAccessClaims struct {
19-
jwt.StandardClaims
20-
}
21-
22-
// Valid claims verification
23-
func (a *JWTAccessClaims) Valid() error {
24-
if time.Unix(a.ExpiresAt, 0).Before(time.Now()) {
25-
return errors.ErrInvalidAccessToken
26-
}
27-
return nil
28-
}
29-
30-
// NewJWTAccessGenerate create to generate the jwt access token instance
31-
func NewJWTAccessGenerate(key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
32-
return &JWTAccessGenerate{
33-
SignedKey: key,
34-
SignedMethod: method,
35-
}
36-
}
37-
38-
// JWTAccessGenerate generate the jwt access token
39-
type JWTAccessGenerate struct {
40-
SignedKey []byte
41-
SignedMethod jwt.SigningMethod
42-
}
43-
44-
// Token based on the UUID generated token
45-
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
46-
claims := &JWTAccessClaims{
47-
StandardClaims: jwt.StandardClaims{
48-
Audience: data.Client.GetID(),
49-
Subject: data.UserID,
50-
ExpiresAt: data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(),
51-
},
52-
}
53-
54-
token := jwt.NewWithClaims(a.SignedMethod, claims)
55-
var key interface{}
56-
if a.isEs() {
57-
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
58-
if err != nil {
59-
return "", "", err
60-
}
61-
key = v
62-
} else if a.isRsOrPS() {
63-
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
64-
if err != nil {
65-
return "", "", err
66-
}
67-
key = v
68-
} else if a.isHs() {
69-
key = a.SignedKey
70-
} else {
71-
return "", "", errs.New("unsupported sign method")
72-
}
73-
74-
access, err := token.SignedString(key)
75-
if err != nil {
76-
return "", "", err
77-
}
78-
refresh := ""
79-
80-
if isGenRefresh {
81-
refresh = base64.URLEncoding.EncodeToString(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).Bytes())
82-
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
83-
}
84-
85-
return access, refresh, nil
86-
}
87-
88-
func (a *JWTAccessGenerate) isEs() bool {
89-
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
90-
}
91-
92-
func (a *JWTAccessGenerate) isRsOrPS() bool {
93-
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
94-
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
95-
return isRs || isPs
96-
}
97-
98-
func (a *JWTAccessGenerate) isHs() bool {
99-
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
100-
}
1+
package generates
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"strings"
7+
"time"
8+
9+
errs "errors"
10+
11+
"github.com/dgrijalva/jwt-go"
12+
"gopkg.in/oauth2.v4"
13+
"gopkg.in/oauth2.v4/errors"
14+
"gopkg.in/oauth2.v4/utils/uuid"
15+
)
16+
17+
// JWTAccessClaims jwt claims
18+
type JWTAccessClaims struct {
19+
jwt.StandardClaims
20+
}
21+
22+
// Valid claims verification
23+
func (a *JWTAccessClaims) Valid() error {
24+
if time.Unix(a.ExpiresAt, 0).Before(time.Now()) {
25+
return errors.ErrInvalidAccessToken
26+
}
27+
return nil
28+
}
29+
30+
// NewJWTAccessGenerate create to generate the jwt access token instance
31+
func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
32+
return &JWTAccessGenerate{
33+
SignedKeyId: kid,
34+
SignedKey: key,
35+
SignedMethod: method,
36+
}
37+
}
38+
39+
// JWTAccessGenerate generate the jwt access token
40+
type JWTAccessGenerate struct {
41+
SignedKeyId string
42+
SignedKey []byte
43+
SignedMethod jwt.SigningMethod
44+
}
45+
46+
// Token based on the UUID generated token
47+
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
48+
claims := &JWTAccessClaims{
49+
StandardClaims: jwt.StandardClaims{
50+
Audience: data.Client.GetID(),
51+
Subject: data.UserID,
52+
ExpiresAt: data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(),
53+
},
54+
}
55+
56+
token := jwt.NewWithClaims(a.SignedMethod, claims)
57+
if a.SignedKeyId != "" {
58+
token.Header["kid"] = a.SignedKeyId
59+
}
60+
var key interface{}
61+
if a.isEs() {
62+
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
63+
if err != nil {
64+
return "", "", err
65+
}
66+
key = v
67+
} else if a.isRsOrPS() {
68+
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
69+
if err != nil {
70+
return "", "", err
71+
}
72+
key = v
73+
} else if a.isHs() {
74+
key = a.SignedKey
75+
} else {
76+
return "", "", errs.New("unsupported sign method")
77+
}
78+
79+
access, err := token.SignedString(key)
80+
if err != nil {
81+
return "", "", err
82+
}
83+
refresh := ""
84+
85+
if isGenRefresh {
86+
refresh = base64.URLEncoding.EncodeToString(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).Bytes())
87+
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
88+
}
89+
90+
return access, refresh, nil
91+
}
92+
93+
func (a *JWTAccessGenerate) isEs() bool {
94+
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
95+
}
96+
97+
func (a *JWTAccessGenerate) isRsOrPS() bool {
98+
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
99+
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
100+
return isRs || isPs
101+
}
102+
103+
func (a *JWTAccessGenerate) isHs() bool {
104+
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
105+
}

generates/jwt_access_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestJWTAccess(t *testing.T) {
2828
},
2929
}
3030

31-
gen := generates.NewJWTAccessGenerate([]byte("00000000"), jwt.SigningMethodHS512)
31+
gen := generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512)
3232
access, refresh, err := gen.Token(context.Background(), data, true)
3333
So(err, ShouldBeNil)
3434
So(access, ShouldNotBeEmpty)

0 commit comments

Comments
 (0)