Skip to content
Merged
80 changes: 77 additions & 3 deletions common.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package authcontrol

import (
"cmp"
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"strconv"
"strings"

"github.com/0xsequence/authcontrol/proto"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
)

const (
Expand Down Expand Up @@ -42,10 +49,11 @@ type UserStore interface {
GetUser(ctx context.Context, address string) (user any, isAdmin bool, err error)
}

// ProjectStore is a pluggable backend that verifies if the project exists.
// If the project doesn't exist, it should return nil, nil.
// ProjectStore is a pluggable backend that verifies if a project exists.
// If the project does not exist, it should return nil, nil, nil.
// The optional Auth, when returned, will be used for instead of the standard one.
type ProjectStore interface {
GetProject(ctx context.Context, id uint64) (project any, err error)
GetProject(ctx context.Context, id uint64) (project any, auth *Auth, err error)
}

// Config is a generic map of services/methods to a config value.
Expand Down Expand Up @@ -121,3 +129,69 @@ func (a ACL) And(session ...proto.SessionType) ACL {
func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

// NewAuth creates a new Auth HS256 with the given secret.
func NewAuth(secret string) *Auth {
return &Auth{Algorithm: jwa.HS256, Private: []byte(secret)}
}

// Auth is a struct that holds the private and public keys for JWT signing and verification.
type Auth struct {
Algorithm jwa.SignatureAlgorithm
Private []byte
Public []byte
}

// GetVerifier returns a JWTAuth using the private secret when available, otherwise the public key
func (a Auth) GetVerifier(options ...jwt.ValidateOption) (*jwtauth.JWTAuth, error) {
if a.Algorithm == "" {
return nil, fmt.Errorf("missing algorithm")
}

if a.Private != nil {
return jwtauth.New(string(a.Algorithm), a.Private, a.Private, options...), nil
}

if a.Public == nil {
return nil, fmt.Errorf("missing public key")
}

block, _ := pem.Decode(a.Public)

pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse public key: %w", err)
}

return jwtauth.New(a.Algorithm.String(), nil, pub, options...), nil
}

// findProjectClaim looks for the project_id/project claim in the JWT
func findProjectClaim(r *http.Request) (uint64, error) {
raw := cmp.Or(jwtauth.TokenFromHeader(r))

token, err := jwt.ParseString(raw, jwt.WithVerify(false))
if err != nil {
return 0, fmt.Errorf("parse token: %w", err)
}

claims := token.PrivateClaims()

claim := cmp.Or(claims["project_id"], claims["project"])
if claim == nil {
return 0, fmt.Errorf("missing project claim")
}

switch val := claim.(type) {
case float64:
return uint64(val), nil
case string:
v, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return 0, fmt.Errorf("invalid value")
}
return v, nil
default:
return 0, fmt.Errorf("invalid type: %T", val)
}
}
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ var (
ctxKeyUser = &contextKey{"User"}
ctxKeyService = &contextKey{"Service"}
ctxKeyAccessKey = &contextKey{"AccessKey"}
ctxKeyProject = &contextKey{"Project"}
ctxKeyProjectID = &contextKey{"ProjectID"}
ctxKeyProject = &contextKey{"Project"}
)

//
Expand Down
114 changes: 78 additions & 36 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ import (

// Options for the authcontrol middleware handlers Session and AccessControl.
type Options struct {
// JWT secret used to verify the JWT token.
// JWTsecret is required, and it is used for the JWT verification.
// If a Project Store is also provided and the request has a project claim,
// it could be replaced by the a specific verifier.
JWTSecret string

// ProjectStore is a pluggable backends that verifies if the project from the claim exists.
// When provived, it checks the Project from the JWT, and can override the JWT Auth.
ProjectStore ProjectStore

// AccessKeyFuncs are used to extract the access key from the request.
AccessKeyFuncs []AccessKeyFunc

// UserStore is a pluggable backends that verifies if the account exists.
// When provided, it can upgrade a Wallet session to a User or Admin session.
UserStore UserStore

// ProjectStore is a pluggable backends that verifies if the project exists.
ProjectStore ProjectStore

// ErrHandler is a function that is used to handle and respond to errors.
ErrHandler ErrHandler
}
Expand All @@ -46,34 +50,47 @@ func (o *Options) ApplyDefaults() {
}
}

func Session(cfg Options) func(next http.Handler) http.Handler {
func VerifyToken(cfg Options) func(next http.Handler) http.Handler {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we could split Options struct into two types -- one for Verifier and one for Sessions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imho, this middleware needs nothing but JWTSecret and ErrHandler. It should always verify HS256 JWT from both Authorization header and Cookie.

I don't see the need for AuthProvider interface. Is it useful for anything else? I think it just locks us down to this implementation -- switching algorithms, jwks or supporting secret rotation or multiple algorithms would be more difficult.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imho, stack/api's ProjectJWTVerifier should be a a separate implementation that doesn't have anything in common with this Verifier except for passing the token into context via ctx = jwtauth.NewContext(ctx, token, nil)

cfg.ApplyDefaults()
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil, jwt.WithAcceptableSkew(2*time.Minute))
jwtOptions := []jwt.ValidateOption{
jwt.WithAcceptableSkew(2 * time.Minute),
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// check if the request already contains session, if it does then continue
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
return
}
auth := NewAuth(cfg.JWTSecret)

var (
sessionType proto.SessionType
accessKey string
token jwt.Token
)
if cfg.ProjectStore != nil {
projectID, err := findProjectClaim(r)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project claim: %w", err))
return
}

for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
project, _auth, err := cfg.ProjectStore.GetProject(ctx, projectID)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get project: %w", err))
return
}
if project == nil {
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
return
}
if _auth != nil {
auth = _auth
}
ctx = WithProject(ctx, project)
}

// Verify JWT token and validate its claims.
token, err := jwtauth.VerifyRequest(auth, r, jwtauth.TokenFromHeader)
jwtAuth, err := auth.GetVerifier(jwtOptions...)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get verifier: %w", err))
return
}

token, err := jwtauth.VerifyRequest(jwtAuth, r, jwtauth.TokenFromHeader)
if err != nil {
if errors.Is(err, jwtauth.ErrExpired) {
cfg.ErrHandler(r, w, proto.ErrSessionExpired)
Expand All @@ -89,7 +106,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
if token != nil {
claims, err := token.AsMap(ctx)
if err != nil {
cfg.ErrHandler(r, w, err)
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("invalid token: %w", err))
return
}

Expand All @@ -102,6 +119,44 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}
}

ctx = jwtauth.NewContext(ctx, token, nil)
}

next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

func Session(cfg Options) func(next http.Handler) http.Handler {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Session needs ErrorHandler, UserStore and a way to fetch AccessKey.

cfg.ApplyDefaults()

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// if a custom middleware already sets the session type, skip this middleware
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
return
}

var (
accessKey string
sessionType proto.SessionType
)

for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
}

_, claims, err := jwtauth.FromContext(ctx)
if err != nil {
cfg.ErrHandler(r, w, err)
return
}
if claims != nil {
serviceClaim, _ := claims["service"].(string)
accountClaim, _ := claims["account"].(string)
adminClaim, _ := claims["admin"].(bool)
Expand Down Expand Up @@ -140,20 +195,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}

if projectClaim > 0 {
projectID := uint64(projectClaim)
if cfg.ProjectStore != nil {
project, err := cfg.ProjectStore.GetProject(ctx, projectID)
if err != nil {
cfg.ErrHandler(r, w, err)
return
}
if project == nil {
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
return
}
ctx = WithProject(ctx, project)
}
ctx = WithProjectID(ctx, projectID)
ctx = WithProjectID(ctx, uint64(projectClaim))
sessionType = proto.SessionType_Project
}
}
Expand Down
Loading
Loading