Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 96 additions & 45 deletions auth/auth_test.go

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions auth/collection_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type CollectionChannelAPI interface {
SetCollectionChannelHistory(scope, collection string, history TimedSetHistory)

// Returns true if the Principal has access to the given channel.
CanSeeCollectionChannel(scope, collection, channel string) bool
CanSeeCollectionChannel(scope, collection, channel string) (bool, error)

// Retrieve invalidation sequence for a collection
getCollectionChannelInvalSeq(scope, collection string) uint64
Expand All @@ -52,7 +52,7 @@ type CollectionChannelAPI interface {

// If the Principal has access to the given collection's channel, returns the sequence number at which
// access was granted; else returns zero.
canSeeCollectionChannelSince(scope, collection, channel string) uint64
canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error)

// Returns an error if the Principal does not have access to all the channels in the set, for the specified collection.
authorizeAllCollectionChannels(scope, collection string, channels base.Set) error
Expand All @@ -76,23 +76,23 @@ type UserCollectionChannelAPI interface {
SetCollectionJWTChannels(scope, collection string, channels ch.TimedSet, seq uint64)

// Retrieves revoked channels for a collection, based on the given since value
RevokedCollectionChannels(scope, collection string, since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels
RevokedCollectionChannels(scope, collection string, since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error)

// Obtains the period over which the user had access to the given collection's channel. Either directly or via a role.
CollectionChannelGrantedPeriods(scope, collection, chanName string) ([]GrantHistorySequencePair, error)

// Every channel the user has access to in the collection, including those inherited from Roles.
InheritedCollectionChannels(scope, collection string) ch.TimedSet
InheritedCollectionChannels(scope, collection string) (ch.TimedSet, error)

// Returns a TimedSet containing only the channels from the input set that the user has access
// to for the collection, annotated with the sequence number at which access was granted.
// Returns a string array containing any channels filtered out due to the user not having access
// to them.
FilterToAvailableCollectionChannels(scope, collection string, channels base.Set) (filtered ch.TimedSet, removed []string)
FilterToAvailableCollectionChannels(scope, collection string, channels base.Set) (filtered ch.TimedSet, removed []string, err error)

// If the input set contains the wildcard "*" channel, returns the user's inheritedChannels for the collection;
// else returns the input channel list unaltered.
expandCollectionWildCardChannel(scope, collection string, channels base.Set) base.Set
expandCollectionWildCardChannel(scope, collection string, channels base.Set) (base.Set, error)
}

// PrincipalCollectionAccess defines a common interface for principal access control. This interface is
Expand Down
12 changes: 9 additions & 3 deletions auth/collection_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ import (
// requireCanSeeCollectionChannels asserts that the principal can see all the specified channels in the given collection
func requireCanSeeCollectionChannels(t *testing.T, scope, collection string, princ Principal, channels ...string) {
for _, channel := range channels {
require.True(t, princ.CanSeeCollectionChannel(scope, collection, channel), "Expected %s to be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection)
canSee, err := princ.CanSeeCollectionChannel(scope, collection, channel)
require.NoError(t, err)
require.True(t, canSee, "Expected %s to be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection)
}
}

// requireCannotSeeCollectionChannels asserts that the principal cannot see any of the specified channels in the given collection
func requireCannotSeeCollectionChannels(t *testing.T, scope, collection string, princ Principal, channels ...string) {
for _, channel := range channels {
require.False(t, princ.CanSeeCollectionChannel(scope, collection, channel), "Expected %s to NOT be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection)
canSee, err := princ.CanSeeCollectionChannel(scope, collection, channel)
require.NoError(t, err)
require.False(t, canSee, "Expected %s to NOT be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection)
}
}

Expand Down Expand Up @@ -277,5 +281,7 @@ func TestPrincipalConfigSetExplicitChannels(t *testing.T) {

// requireExpandCollectionWildCardChannels asserts that the channels will be expanded to the expected channels for the given collection
func requireExpandCollectionWildCardChannels(t *testing.T, user User, scope, collection string, expectedChannels []string, channelsToExpand []string) {
require.Equal(t, base.SetFromArray(expectedChannels), user.expandCollectionWildCardChannel(scope, collection, base.SetFromArray(channelsToExpand)), "Expected channels %v for %s.%s from %v on user %s", expectedChannels, scope, collection, channelsToExpand, user.Name())
expandedChannels, err := user.expandCollectionWildCardChannel(scope, collection, base.SetFromArray(channelsToExpand))
require.NoError(t, err)
require.Equal(t, base.SetFromArray(expectedChannels), expandedChannels, "Expected channels %v for %s.%s from %v on user %s", expectedChannels, scope, collection, channelsToExpand, user.Name())
}
16 changes: 8 additions & 8 deletions auth/principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ type Principal interface {
SetSequence(sequence uint64)

// Returns true if the Principal has access to the given channel.
canSeeChannel(channel string) bool
canSeeChannel(channel string) (bool, error)

// If the Principal has access to the given channel, returns the sequence number at which
// access was granted; else returns zero.
canSeeChannelSince(channel string) uint64
canSeeChannelSince(channel string) (uint64, error)

// Returns an error if the Principal does not have access to all the channels in the set.
authorizeAllChannels(channels base.Set) error
Expand Down Expand Up @@ -104,7 +104,7 @@ type User interface {
SetPassword(password string) error

// GetRoles returns the set of roles the user belongs to, initializing them if necessary.
GetRoles() []Role
GetRoles() ([]Role, error)

// The set of Roles the user belongs to (including ones given to it by the sync function and by OIDC/JWT)
// Returns nil if invalidated
Expand Down Expand Up @@ -135,25 +135,25 @@ type User interface {

RoleHistory() TimedSetHistory

InitializeRoles()
InitializeRoles() error

revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels
revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error)

// Obtains the period over which the user had access to the given channel. Either directly or via a role.
channelGrantedPeriods(chanName string) ([]GrantHistorySequencePair, error)

// Every channel the user has access to, including those inherited from Roles.
inheritedChannels() ch.TimedSet
inheritedChannels() (ch.TimedSet, error)

// If the input set contains the wildcard "*" channel, returns the user's InheritedChannels;
// else returns the input channel list unaltered.
expandWildCardChannel(channels base.Set) base.Set
expandWildCardChannel(channels base.Set) (base.Set, error)

// Returns a TimedSet containing only the channels from the input set that the user has access
// to, annotated with the sequence number at which access was granted.
// Returns a string array containing any channels filtered out due to the user not having access
// to them.
filterToAvailableChannels(channels base.Set) (filtered ch.TimedSet, removed []string)
filterToAvailableChannels(channels base.Set) (filtered ch.TimedSet, removed []string, err error)

setRolesSince(ch.TimedSet)

Expand Down
20 changes: 14 additions & 6 deletions auth/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,17 +373,17 @@ func (role *roleImpl) UnauthError(err error) error {

// Returns true if the Role is allowed to access the channel.
// A nil Role means access control is disabled, so the function will return true.
func (role *roleImpl) canSeeChannel(channel string) bool {
return role == nil || role.Channels().Contains(channel) || role.Channels().Contains(ch.UserStarChannel)
func (role *roleImpl) canSeeChannel(channel string) (bool, error) {
return role == nil || role.Channels().Contains(channel) || role.Channels().Contains(ch.UserStarChannel), nil
}

// Returns the sequence number since which the Role has been able to access the channel, else zero.
func (role *roleImpl) canSeeChannelSince(channel string) uint64 {
func (role *roleImpl) canSeeChannelSince(channel string) (uint64, error) {
seq := role.Channels()[channel]
if seq.Sequence == 0 {
seq = role.Channels()[ch.UserStarChannel]
}
return seq.Sequence
return seq.Sequence, nil
}

func (role *roleImpl) authorizeAllChannels(channels base.Set) error {
Expand All @@ -399,7 +399,11 @@ func (role *roleImpl) authorizeAnyChannel(channels base.Set) error {
func authorizeAllChannels(princ Principal, channels base.Set) error {
var forbidden []string
for channel := range channels {
if !princ.canSeeChannel(channel) {
canSee, err := princ.canSeeChannel(channel)
if err != nil {
return err
}
if !canSee {
if forbidden == nil {
forbidden = make([]string, 0, len(channels))
}
Expand All @@ -417,7 +421,11 @@ func authorizeAllChannels(princ Principal, channels base.Set) error {
func authorizeAnyChannel(princ Principal, channels base.Set) error {
if len(channels) > 0 {
for channel := range channels {
if princ.canSeeChannel(channel) {
canSee, err := princ.canSeeChannel(channel)
if err != nil {
return err
}
if canSee {
return nil
}
}
Expand Down
14 changes: 7 additions & 7 deletions auth/role_collection_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,22 +147,22 @@ func (role *roleImpl) SetCollectionChannelHistory(scope, collection string, hist

// Returns true if the Role is allowed to access the channel.
// A nil Role means access control is disabled, so the function will return true.
func (role *roleImpl) CanSeeCollectionChannel(scope, collection, channel string) bool {
func (role *roleImpl) CanSeeCollectionChannel(scope, collection, channel string) (bool, error) {
if base.IsDefaultCollection(scope, collection) {
return role.canSeeChannel(channel)
}

if role == nil {
return true
return true, nil
}
if cc, ok := role.getCollectionAccess(scope, collection); ok {
return cc.CanSeeChannel(channel)
return cc.CanSeeChannel(channel), nil
}
return false
return false, nil
}

// Returns the sequence number since which the Role has been able to access the channel, else zero.
func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel string) uint64 {
func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error) {
if base.IsDefaultCollection(scope, collection) {
return role.canSeeChannelSince(channel)
}
Expand All @@ -172,9 +172,9 @@ func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel st
if seq.Sequence == 0 {
seq = cc.Channels()[ch.UserStarChannel]
}
return seq.Sequence
return seq.Sequence, nil
}
return 0
return 0, nil
}

func (role *roleImpl) authorizeAllCollectionChannels(scope, collection string, channels base.Set) error {
Expand Down
Loading
Loading