Skip to content

Commit 83697b0

Browse files
authored
fix(server): respond with forbidden if failed to authenticate (#4200)
Signed-off-by: Aljoscha Bollmann <aljoscha.bollmann@proton.me>
1 parent 25591ee commit 83697b0

File tree

5 files changed

+61
-4
lines changed

5 files changed

+61
-4
lines changed

connector/connector.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,22 @@ package connector
33

44
import (
55
"context"
6+
"fmt"
67
"net/http"
78
)
89

10+
// UserNotInRequiredGroupsError is returned by a connector when a user
11+
// successfully authenticates but is not a member of any of the required groups.
12+
// The server will respond with HTTP 403 Forbidden instead of 500.
13+
type UserNotInRequiredGroupsError struct {
14+
UserID string
15+
Groups []string
16+
}
17+
18+
func (e *UserNotInRequiredGroupsError) Error() string {
19+
return fmt.Sprintf("user %q is not in any of the required groups %v", e.UserID, e.Groups)
20+
}
21+
922
// Connector is a mechanism for federating login to a remote identity service.
1023
//
1124
// Implementations are expected to implement either the PasswordConnector or

connector/microsoft/microsoft.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte,
227227
if c.groupsRequired(s.Groups) {
228228
groups, err := c.getGroups(ctx, client, user.ID)
229229
if err != nil {
230-
return identity, fmt.Errorf("microsoft: get groups: %v", err)
230+
return identity, fmt.Errorf("microsoft: get groups: %w", err)
231231
}
232232
identity.Groups = groups
233233
}
@@ -318,7 +318,7 @@ func (c *microsoftConnector) Refresh(ctx context.Context, s connector.Scopes, id
318318
if c.groupsRequired(s.Groups) {
319319
groups, err := c.getGroups(ctx, client, user.ID)
320320
if err != nil {
321-
return identity, fmt.Errorf("microsoft: get groups: %v", err)
321+
return identity, fmt.Errorf("microsoft: get groups: %w", err)
322322
}
323323
identity.Groups = groups
324324
}
@@ -404,7 +404,7 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client,
404404
// ensure that the user is in at least one required group
405405
filteredGroups := groups_pkg.Filter(userGroups, c.groups)
406406
if len(c.groups) > 0 && len(filteredGroups) == 0 {
407-
return nil, fmt.Errorf("microsoft: user %v not in any of the required groups", userID)
407+
return nil, &connector.UserNotInRequiredGroupsError{UserID: userID, Groups: c.groups}
408408
} else if c.useGroupsAsWhitelist {
409409
return filteredGroups, nil
410410
}

connector/microsoft/microsoft_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package microsoft
22

33
import (
44
"encoding/json"
5+
"errors"
56
"fmt"
67
"net/http"
78
"net/http/httptest"
@@ -119,6 +120,39 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
119120
expectEquals(t, identity.Groups, []string{"a", "b"})
120121
}
121122

123+
func TestUserNotInRequiredGroupFromGraphAPI(t *testing.T) {
124+
s := newTestServer(map[string]testResponse{
125+
"/v1.0/me?$select=id,displayName,userPrincipalName": {
126+
data: user{ID: "user-id-123", Name: "Jane Doe", Email: "jane.doe@example.com"},
127+
},
128+
// The user is a member of groups "c" and "d", but the connector only
129+
// allows group "a" — so the user should be denied.
130+
"/v1.0/me/getMemberGroups": {data: map[string]interface{}{
131+
"value": []string{"c", "d"},
132+
}},
133+
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
134+
})
135+
defer s.Close()
136+
137+
req, _ := http.NewRequest("GET", s.URL, nil)
138+
139+
c := microsoftConnector{
140+
apiURL: s.URL,
141+
graphURL: s.URL,
142+
tenant: tenant,
143+
groups: []string{"a"},
144+
}
145+
_, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
146+
if err == nil {
147+
t.Fatal("expected error when user is not in any required group, got nil")
148+
}
149+
150+
var groupsErr *connector.UserNotInRequiredGroupsError
151+
if !errors.As(err, &groupsErr) {
152+
t.Errorf("expected *connector.UserNotInRequiredGroupsError, got %T: %v", err, err)
153+
}
154+
}
155+
122156
func newTestServer(responses map[string]testResponse) *httptest.Server {
123157
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124158
response, found := responses[r.RequestURI]

server/errors.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,8 @@ const (
2323

2424
// ErrMsgMethodNotAllowed is shown when an unsupported HTTP method is used.
2525
ErrMsgMethodNotAllowed = "Method not allowed."
26+
27+
// ErrMsgNotInRequiredGroups is shown when a user authenticates successfully
28+
// but is not a member of any of the groups required by the connector.
29+
ErrMsgNotInRequiredGroups = "You are not a member of any of the required groups to authenticate."
2630
)

server/handlers.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/subtle"
88
"encoding/base64"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"html/template"
1213
"net/http"
@@ -499,7 +500,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
499500

500501
if err != nil {
501502
s.logger.ErrorContext(r.Context(), "failed to authenticate", "err", err)
502-
s.renderError(r, w, http.StatusInternalServerError, ErrMsgAuthenticationFailed)
503+
var groupsErr *connector.UserNotInRequiredGroupsError
504+
if errors.As(err, &groupsErr) {
505+
s.renderError(r, w, http.StatusForbidden, ErrMsgNotInRequiredGroups)
506+
} else {
507+
s.renderError(r, w, http.StatusInternalServerError, ErrMsgAuthenticationFailed)
508+
}
503509
return
504510
}
505511

0 commit comments

Comments
 (0)