Skip to content

Commit e9b255d

Browse files
authored
Split session (#25)
1 parent f5ebb9d commit e9b255d

File tree

9 files changed

+246
-88
lines changed

9 files changed

+246
-88
lines changed

common.go

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
package authcontrol
22

33
import (
4+
"cmp"
45
"context"
6+
"crypto/x509"
57
"encoding/json"
8+
"encoding/pem"
69
"errors"
710
"fmt"
811
"net/http"
12+
"strconv"
913
"strings"
1014

1115
"github.com/0xsequence/authcontrol/proto"
16+
"github.com/go-chi/jwtauth/v5"
17+
"github.com/lestrrat-go/jwx/v2/jwa"
18+
"github.com/lestrrat-go/jwx/v2/jwt"
1219
)
1320

1421
const (
@@ -42,10 +49,11 @@ type UserStore interface {
4249
GetUser(ctx context.Context, address string) (user any, isAdmin bool, err error)
4350
}
4451

45-
// ProjectStore is a pluggable backend that verifies if the project exists.
46-
// If the project doesn't exist, it should return nil, nil.
52+
// ProjectStore is a pluggable backend that verifies if a project exists.
53+
// If the project does not exist, it should return nil, nil, nil.
54+
// The optional Auth, when returned, will be used for instead of the standard one.
4755
type ProjectStore interface {
48-
GetProject(ctx context.Context, id uint64) (project any, err error)
56+
GetProject(ctx context.Context, id uint64) (project any, auth *Auth, err error)
4957
}
5058

5159
// Config is a generic map of services/methods to a config value.
@@ -121,3 +129,69 @@ func (a ACL) And(session ...proto.SessionType) ACL {
121129
func (t ACL) Includes(session proto.SessionType) bool {
122130
return t&ACL(1<<session) != 0
123131
}
132+
133+
// NewAuth creates a new Auth HS256 with the given secret.
134+
func NewAuth(secret string) *Auth {
135+
return &Auth{Algorithm: jwa.HS256, Private: []byte(secret)}
136+
}
137+
138+
// Auth is a struct that holds the private and public keys for JWT signing and verification.
139+
type Auth struct {
140+
Algorithm jwa.SignatureAlgorithm
141+
Private []byte
142+
Public []byte
143+
}
144+
145+
// GetVerifier returns a JWTAuth using the private secret when available, otherwise the public key
146+
func (a Auth) GetVerifier(options ...jwt.ValidateOption) (*jwtauth.JWTAuth, error) {
147+
if a.Algorithm == "" {
148+
return nil, fmt.Errorf("missing algorithm")
149+
}
150+
151+
if a.Private != nil {
152+
return jwtauth.New(string(a.Algorithm), a.Private, a.Private, options...), nil
153+
}
154+
155+
if a.Public == nil {
156+
return nil, fmt.Errorf("missing public key")
157+
}
158+
159+
block, _ := pem.Decode(a.Public)
160+
161+
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
162+
if err != nil {
163+
return nil, fmt.Errorf("parse public key: %w", err)
164+
}
165+
166+
return jwtauth.New(a.Algorithm.String(), nil, pub, options...), nil
167+
}
168+
169+
// findProjectClaim looks for the project_id/project claim in the JWT
170+
func findProjectClaim(r *http.Request) (uint64, error) {
171+
raw := cmp.Or(jwtauth.TokenFromHeader(r))
172+
173+
token, err := jwt.ParseString(raw, jwt.WithVerify(false))
174+
if err != nil {
175+
return 0, fmt.Errorf("parse token: %w", err)
176+
}
177+
178+
claims := token.PrivateClaims()
179+
180+
claim := cmp.Or(claims["project_id"], claims["project"])
181+
if claim == nil {
182+
return 0, fmt.Errorf("missing project claim")
183+
}
184+
185+
switch val := claim.(type) {
186+
case float64:
187+
return uint64(val), nil
188+
case string:
189+
v, err := strconv.ParseUint(val, 10, 64)
190+
if err != nil {
191+
return 0, fmt.Errorf("invalid value")
192+
}
193+
return v, nil
194+
default:
195+
return 0, fmt.Errorf("invalid type: %T", val)
196+
}
197+
}

context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ var (
2020
ctxKeyUser = &contextKey{"User"}
2121
ctxKeyService = &contextKey{"Service"}
2222
ctxKeyAccessKey = &contextKey{"AccessKey"}
23-
ctxKeyProject = &contextKey{"Project"}
2423
ctxKeyProjectID = &contextKey{"ProjectID"}
24+
ctxKeyProject = &contextKey{"Project"}
2525
)
2626

2727
//

middleware.go

Lines changed: 78 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,22 @@ import (
1616

1717
// Options for the authcontrol middleware handlers Session and AccessControl.
1818
type Options struct {
19-
// JWT secret used to verify the JWT token.
19+
// JWTsecret is required, and it is used for the JWT verification.
20+
// If a Project Store is also provided and the request has a project claim,
21+
// it could be replaced by the a specific verifier.
2022
JWTSecret string
2123

24+
// ProjectStore is a pluggable backends that verifies if the project from the claim exists.
25+
// When provived, it checks the Project from the JWT, and can override the JWT Auth.
26+
ProjectStore ProjectStore
27+
2228
// AccessKeyFuncs are used to extract the access key from the request.
2329
AccessKeyFuncs []AccessKeyFunc
2430

2531
// UserStore is a pluggable backends that verifies if the account exists.
32+
// When provided, it can upgrade a Wallet session to a User or Admin session.
2633
UserStore UserStore
2734

28-
// ProjectStore is a pluggable backends that verifies if the project exists.
29-
ProjectStore ProjectStore
30-
3135
// ErrHandler is a function that is used to handle and respond to errors.
3236
ErrHandler ErrHandler
3337
}
@@ -46,34 +50,47 @@ func (o *Options) ApplyDefaults() {
4650
}
4751
}
4852

49-
func Session(cfg Options) func(next http.Handler) http.Handler {
53+
func VerifyToken(cfg Options) func(next http.Handler) http.Handler {
5054
cfg.ApplyDefaults()
51-
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil, jwt.WithAcceptableSkew(2*time.Minute))
55+
jwtOptions := []jwt.ValidateOption{
56+
jwt.WithAcceptableSkew(2 * time.Minute),
57+
}
5258

5359
return func(next http.Handler) http.Handler {
5460
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5561
ctx := r.Context()
5662

57-
// check if the request already contains session, if it does then continue
58-
if _, ok := GetSessionType(ctx); ok {
59-
next.ServeHTTP(w, r)
60-
return
61-
}
63+
auth := NewAuth(cfg.JWTSecret)
6264

63-
var (
64-
sessionType proto.SessionType
65-
accessKey string
66-
token jwt.Token
67-
)
65+
if cfg.ProjectStore != nil {
66+
projectID, err := findProjectClaim(r)
67+
if err != nil {
68+
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project claim: %w", err))
69+
return
70+
}
6871

69-
for _, f := range cfg.AccessKeyFuncs {
70-
if accessKey = f(r); accessKey != "" {
71-
break
72+
project, _auth, err := cfg.ProjectStore.GetProject(ctx, projectID)
73+
if err != nil {
74+
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project: %w", err))
75+
return
76+
}
77+
if project == nil {
78+
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
79+
return
80+
}
81+
if _auth != nil {
82+
auth = _auth
7283
}
84+
ctx = WithProject(ctx, project)
7385
}
7486

75-
// Verify JWT token and validate its claims.
76-
token, err := jwtauth.VerifyRequest(auth, r, jwtauth.TokenFromHeader)
87+
jwtAuth, err := auth.GetVerifier(jwtOptions...)
88+
if err != nil {
89+
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get verifier: %w", err))
90+
return
91+
}
92+
93+
token, err := jwtauth.VerifyRequest(jwtAuth, r, jwtauth.TokenFromHeader)
7794
if err != nil {
7895
if errors.Is(err, jwtauth.ErrExpired) {
7996
cfg.ErrHandler(r, w, proto.ErrSessionExpired)
@@ -89,7 +106,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
89106
if token != nil {
90107
claims, err := token.AsMap(ctx)
91108
if err != nil {
92-
cfg.ErrHandler(r, w, err)
109+
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("invalid token: %w", err))
93110
return
94111
}
95112

@@ -102,6 +119,44 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
102119
}
103120
}
104121

122+
ctx = jwtauth.NewContext(ctx, token, nil)
123+
}
124+
125+
next.ServeHTTP(w, r.WithContext(ctx))
126+
})
127+
}
128+
}
129+
130+
func Session(cfg Options) func(next http.Handler) http.Handler {
131+
cfg.ApplyDefaults()
132+
133+
return func(next http.Handler) http.Handler {
134+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135+
ctx := r.Context()
136+
137+
// if a custom middleware already sets the session type, skip this middleware
138+
if _, ok := GetSessionType(ctx); ok {
139+
next.ServeHTTP(w, r)
140+
return
141+
}
142+
143+
var (
144+
accessKey string
145+
sessionType proto.SessionType
146+
)
147+
148+
for _, f := range cfg.AccessKeyFuncs {
149+
if accessKey = f(r); accessKey != "" {
150+
break
151+
}
152+
}
153+
154+
_, claims, err := jwtauth.FromContext(ctx)
155+
if err != nil {
156+
cfg.ErrHandler(r, w, err)
157+
return
158+
}
159+
if claims != nil {
105160
serviceClaim, _ := claims["service"].(string)
106161
accountClaim, _ := claims["account"].(string)
107162
adminClaim, _ := claims["admin"].(bool)
@@ -140,20 +195,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
140195
}
141196

142197
if projectClaim > 0 {
143-
projectID := uint64(projectClaim)
144-
if cfg.ProjectStore != nil {
145-
project, err := cfg.ProjectStore.GetProject(ctx, projectID)
146-
if err != nil {
147-
cfg.ErrHandler(r, w, err)
148-
return
149-
}
150-
if project == nil {
151-
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
152-
return
153-
}
154-
ctx = WithProject(ctx, project)
155-
}
156-
ctx = WithProjectID(ctx, projectID)
198+
ctx = WithProjectID(ctx, uint64(projectClaim))
157199
sessionType = proto.SessionType_Project
158200
}
159201
}

0 commit comments

Comments
 (0)