Skip to content
Merged
6 changes: 0 additions & 6 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ type UserStore interface {
GetUser(ctx context.Context, address string) (user any, isAdmin bool, err error)
}

// ProjectStore is a pluggable backend that verifies if the project exists.
// If the project doesn't exist, it should return nil, nil.
type ProjectStore interface {
GetProject(ctx context.Context, id uint64) (project any, err error)
}

// Config is a generic map of services/methods to a config value.
// map[service]map[method]T
type Config[T any] map[string]map[string]T
Expand Down
18 changes: 0 additions & 18 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ var (
ctxKeyUser = &contextKey{"User"}
ctxKeyService = &contextKey{"Service"}
ctxKeyAccessKey = &contextKey{"AccessKey"}
ctxKeyProject = &contextKey{"Project"}
ctxKeyProjectID = &contextKey{"ProjectID"}
)

Expand Down Expand Up @@ -127,20 +126,3 @@ func GetProjectID(ctx context.Context) (uint64, bool) {
v, ok := ctx.Value(ctxKeyProjectID).(uint64)
return v, ok
}

//
// Project
//

// WithProject adds the project to the context.
//
// TODO: Deprecate this in favor of Session middleware with a JWT token.
func WithProject(ctx context.Context, project any) context.Context {
return context.WithValue(ctx, ctxKeyProject, project)
}

// GetProject returns the project from the context.
func GetProject[T any](ctx context.Context) (*T, bool) {
v, ok := ctx.Value(ctxKeyProject).(*T)
return v, ok
}
85 changes: 48 additions & 37 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@ import (
// Options for the authcontrol middleware handlers Session and AccessControl.
type Options struct {
// JWT secret used to verify the JWT token.
JWTSecret string
Verifier AuthProvider

// AccessKeyFuncs are used to extract the access key from the request.
AccessKeyFuncs []AccessKeyFunc

// UserStore is a pluggable backends that verifies if the account exists.
UserStore UserStore

// ProjectStore is a pluggable backends that verifies if the project exists.
ProjectStore ProjectStore

// ErrHandler is a function that is used to handle and respond to errors.
ErrHandler ErrHandler
}
Expand All @@ -46,33 +43,22 @@ func (o *Options) ApplyDefaults() {
}
}

func Session(cfg Options) func(next http.Handler) http.Handler {
func VerifyToken(cfg Options) func(next http.Handler) http.Handler {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we could split Options struct into two types -- one for Verifier and one for Sessions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imho, this middleware needs nothing but JWTSecret and ErrHandler. It should always verify HS256 JWT from both Authorization header and Cookie.

I don't see the need for AuthProvider interface. Is it useful for anything else? I think it just locks us down to this implementation -- switching algorithms, jwks or supporting secret rotation or multiple algorithms would be more difficult.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imho, stack/api's ProjectJWTVerifier should be a a separate implementation that doesn't have anything in common with this Verifier except for passing the token into context via ctx = jwtauth.NewContext(ctx, token, nil)

cfg.ApplyDefaults()
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil, jwt.WithAcceptableSkew(2*time.Minute))
jwtOptions := []jwt.ValidateOption{
jwt.WithAcceptableSkew(2 * time.Minute),
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// check if the request already contains session, if it does then continue
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
auth, err := cfg.Verifier.GetJWTAuth(r, jwtOptions...)
if err != nil {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get verifier: %w", err))
return
}

var (
sessionType proto.SessionType
accessKey string
token jwt.Token
)

for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
}

// Verify JWT token and validate its claims.
token, err := jwtauth.VerifyRequest(auth, r, jwtauth.TokenFromHeader)
if err != nil {
if errors.Is(err, jwtauth.ErrExpired) {
Expand All @@ -89,7 +75,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
if token != nil {
claims, err := token.AsMap(ctx)
if err != nil {
cfg.ErrHandler(r, w, err)
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("invalid token: %w", err))
return
}

Expand All @@ -102,6 +88,44 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}
}

ctx = jwtauth.NewContext(ctx, token, nil)
}

next.ServeHTTP(w, r.WithContext(ctx))
})
}
}

func Session(cfg Options) func(next http.Handler) http.Handler {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Session needs ErrorHandler, UserStore and a way to fetch AccessKey.

cfg.ApplyDefaults()

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// check if the request already contains session, if it does then continue
if _, ok := GetSessionType(ctx); ok {
next.ServeHTTP(w, r)
return
}

var (
accessKey string
sessionType proto.SessionType
)

for _, f := range cfg.AccessKeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
}

_, claims, err := jwtauth.FromContext(ctx)
if err != nil {
cfg.ErrHandler(r, w, err)
return
}
if claims != nil {
serviceClaim, _ := claims["service"].(string)
accountClaim, _ := claims["account"].(string)
adminClaim, _ := claims["admin"].(bool)
Expand Down Expand Up @@ -140,20 +164,7 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}

if projectClaim > 0 {
projectID := uint64(projectClaim)
if cfg.ProjectStore != nil {
project, err := cfg.ProjectStore.GetProject(ctx, projectID)
if err != nil {
cfg.ErrHandler(r, w, err)
return
}
if project == nil {
cfg.ErrHandler(r, w, proto.ErrProjectNotFound)
return
}
ctx = WithProject(ctx, project)
}
ctx = WithProjectID(ctx, projectID)
ctx = WithProjectID(ctx, uint64(projectClaim))
sessionType = proto.SessionType_Project
}
}
Expand Down
123 changes: 88 additions & 35 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@ package authcontrol_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"strings"
"testing"
"time"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/authcontrol/proto"
Expand All @@ -31,17 +37,6 @@ func (m MockUserStore) GetUser(ctx context.Context, address string) (user any, i
return struct{}{}, v, nil
}

// MockProjectStore is a simple in-memory Project store for testing, it stores the project.
type MockProjectStore map[uint64]struct{}

// GetProject returns the project from the store.
func (m MockProjectStore) GetProject(ctx context.Context, id uint64) (project any, err error) {
if _, ok := m[id]; !ok {
return nil, nil
}
return struct{}{}, nil
}

func TestSession(t *testing.T) {
const (
MethodPublic = "MethodPublic"
Expand Down Expand Up @@ -75,22 +70,19 @@ func TestSession(t *testing.T) {
)

options := authcontrol.Options{
JWTSecret: JWTSecret,
Verifier: authcontrol.NewAuth(JWTSecret),
UserStore: MockUserStore{
UserAddress: false,
AdminAddress: true,
},
ProjectStore: MockProjectStore{
ProjectID: struct{}{},
},
AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc},
}

r := chi.NewRouter()
r.Use(
authcontrol.Session(options),
authcontrol.AccessControl(ACLConfig, options),
)
r.Use(authcontrol.VerifyToken(options))
r.Use(authcontrol.Session(options))
r.Use(authcontrol.AccessControl(ACLConfig, options))

r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

ctx := context.Background()
Expand Down Expand Up @@ -199,18 +191,16 @@ func TestInvalid(t *testing.T) {
)

options := authcontrol.Options{
JWTSecret: JWTSecret,
Verifier: authcontrol.NewAuth(JWTSecret),
UserStore: MockUserStore{
UserAddress: false,
AdminAddress: true,
},
ProjectStore: MockProjectStore{
ProjectID: struct{}{},
},
AccessKeyFuncs: []authcontrol.AccessKeyFunc{keyFunc},
}

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(options))
r.Use(authcontrol.Session(options))
r.Use(authcontrol.AccessControl(ACLConfig, options))

Expand Down Expand Up @@ -281,12 +271,6 @@ func TestInvalid(t *testing.T) {
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Invalid Project
wrongProject := authcontrol.S2SToken(JWTSecret, map[string]any{"account": WalletAddress, "project_id": ProjectID + 1})
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), jwt(wrongProject))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrProjectNotFound)
}

func TestCustomErrHandler(t *testing.T) {
Expand Down Expand Up @@ -318,7 +302,7 @@ func TestCustomErrHandler(t *testing.T) {
}

opts := authcontrol.Options{
JWTSecret: JWTSecret,
Verifier: authcontrol.NewAuth(JWTSecret),
UserStore: MockUserStore{
UserAddress: false,
AdminAddress: true,
Expand All @@ -336,10 +320,10 @@ func TestCustomErrHandler(t *testing.T) {
}

r := chi.NewRouter()
r.Use(
authcontrol.Session(opts),
authcontrol.AccessControl(ACLConfig, opts),
)
r.Use(authcontrol.VerifyToken(opts))
r.Use(authcontrol.Session(opts))
r.Use(authcontrol.AccessControl(ACLConfig, opts))

r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

var claims map[string]any
Expand All @@ -360,10 +344,11 @@ func TestOrigin(t *testing.T) {
ctx := context.Background()

opts := authcontrol.Options{
JWTSecret: JWTSecret,
Verifier: authcontrol.NewAuth(JWTSecret),
}

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(opts))
r.Use(authcontrol.Session(opts))
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

Expand All @@ -387,3 +372,71 @@ func TestOrigin(t *testing.T) {
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
}

type MockAuthStore map[uint64]authcontrol.StaticAuth

func (m MockAuthStore) GetJWTAuth(ctx context.Context, projectID uint64) (*authcontrol.StaticAuth, error) {
auth, ok := m[projectID]
if !ok {
return nil, nil
}
return &auth, nil
}

func TestProjectVerifier(t *testing.T) {
ctx := context.Background()

authStore := MockAuthStore{}

opts := authcontrol.Options{
Verifier: authcontrol.ProjectProvider{
Store: authStore,
},
}

r := chi.NewRouter()
r.Use(authcontrol.VerifyToken(opts))
r.Use(authcontrol.Session(opts))
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

projectID := uint64(7)

authStore[projectID] = authcontrol.StaticAuth{
Algorithm: authcontrol.DefaultAlgorithm,
Private: []byte(JWTSecret),
}

token := authcontrol.S2SToken(JWTSecret, map[string]any{
"project_id": projectID,
})

ok, err := executeRequest(t, ctx, r, "", jwt(token))
assert.True(t, ok)
assert.NoError(t, err)

privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
require.NoError(t, privateKey.Validate())

publicRaw, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)

public := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicRaw,
})

authStore[projectID] = authcontrol.StaticAuth{
Algorithm: "RS256",
Public: public,
}

_, token, err = jwtauth.New("RS256", privateKey, nil).Encode(map[string]any{
"project_id": projectID,
})
require.NoError(t, err)

ok, err = executeRequest(t, ctx, r, "", jwt(token))
assert.True(t, ok)
assert.NoError(t, err)
}
Loading