Skip to content

Commit 62dd952

Browse files
henryluohenryluo
andauthored
Implement Authorization Callback Function (#108)
* Implement Authorization Callback Function * Skip new client unit test if knox daemon is running * add tests for SetAccessCallback * check if knox daemon is running if on linux and using systemctl instead of pgrep * add tests for authorizeRequest * address nits and add testing for panic * modify requests to return a internal server error if the authorizeRequest returns error * run all top level tests, reset callback in tests * Address nits in PR for authorization callback * make static definition of expected error into testCase field --------- Co-authored-by: henryluo <henryluo@pinterest.com>
1 parent 9635b33 commit 62dd952

File tree

7 files changed

+252
-7
lines changed

7 files changed

+252
-7
lines changed

client_test.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package knox
22

33
import (
4+
"bytes"
45
"encoding/json"
56
"net/http"
67
"net/http/httptest"
78
"os"
9+
"os/exec"
810
"path"
911
"reflect"
12+
"runtime"
1013
"sync/atomic"
1114
"testing"
1215
)
@@ -75,7 +78,7 @@ func buildServer(code int, body []byte, a func(r *http.Request)) *httptest.Serve
7578
}))
7679
}
7780

78-
func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byte) *httptest.Server {
81+
func buildConcurrentServer(code int, a func(r *http.Request) []byte) *httptest.Server {
7982
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
8083
resp := a(r)
8184
w.WriteHeader(code)
@@ -84,6 +87,23 @@ func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byt
8487
}))
8588
}
8689

90+
func isKnoxDaemonRunning() bool {
91+
if runtime.GOOS != "linux" {
92+
return false
93+
}
94+
95+
cmd := exec.Command("systemctl", "is-active", "--quiet", "knox")
96+
97+
var out bytes.Buffer
98+
cmd.Stdout = &out
99+
err := cmd.Run()
100+
if err == nil {
101+
return true
102+
}
103+
104+
return false
105+
}
106+
87107
func TestGetKey(t *testing.T) {
88108
expected := Key{
89109
ID: "testkey",
@@ -357,7 +377,7 @@ func TestPutAccess(t *testing.T) {
357377

358378
func TestConcurrentDeletes(t *testing.T) {
359379
var ops uint64
360-
srv := buildConcurrentServer(200, t, func(r *http.Request) []byte {
380+
srv := buildConcurrentServer(200, func(r *http.Request) []byte {
361381
if r.Method != "DELETE" {
362382
t.Fatalf("%s is not DELETE", r.Method)
363383
}
@@ -511,6 +531,10 @@ func TestGetInvalidKeys(t *testing.T) {
511531
}
512532

513533
func TestNewFileClient(t *testing.T) {
534+
if isKnoxDaemonRunning() {
535+
t.Skip("Knox daemon is running, skipping the test.")
536+
}
537+
514538
_, err := NewFileClient("ThisKeyDoesNotExistSoWeExpectAnError")
515539
if (err.Error() != "error getting knox key ThisKeyDoesNotExistSoWeExpectAnError. error: exit status 1") && (err.Error() != "error getting knox key ThisKeyDoesNotExistSoWeExpectAnError. error: exec: \"knox\": executable file not found in $PATH") {
516540
t.Fatal("Unexpected error", err.Error())

knox.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,14 @@ type Principal interface {
508508
CanAccess(ACL, AccessType) bool
509509
GetID() string
510510
Type() string
511+
Raw() []RawPrincipal
512+
}
513+
514+
// RawPrincipal is a serializable version of a principal for passing to
515+
// access callbacks.
516+
type RawPrincipal struct {
517+
ID string `json:"id"`
518+
Type string `json:"type"`
511519
}
512520

513521
// PrincipalMux provides a Principal Interface over multiple Principals.
@@ -564,6 +572,15 @@ func (p PrincipalMux) Default() Principal {
564572
return p.defaultPrincipal
565573
}
566574

575+
// Raw returns the raw version of all the principals.
576+
func (p PrincipalMux) Raw() []RawPrincipal {
577+
raw := []RawPrincipal{}
578+
for _, principal := range p.allPrincipals {
579+
raw = append(raw, principal.Raw()...)
580+
}
581+
return raw
582+
}
583+
567584
// NewPrincipalMux returns a Principal that represents many principals.
568585
func NewPrincipalMux(defaultPrincipal Principal, allPrincipals map[string]Principal) Principal {
569586
return PrincipalMux{
@@ -599,3 +616,10 @@ type Response struct {
599616
Message string `json:"message"`
600617
Data interface{} `json:"data"`
601618
}
619+
620+
// AccessCallbackInput is the input to the access callback function.
621+
type AccessCallbackInput struct {
622+
Key Key `json:"key"`
623+
Principals []RawPrincipal `json:"principals"`
624+
AccessType AccessType `json:"access_type"`
625+
}

server/api.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ func AddDefaultAccess(a *knox.Access) {
294294
defaultAccess = append(defaultAccess, *a)
295295
}
296296

297+
var accessCallback func(knox.AccessCallbackInput) (bool, error)
298+
299+
// SetAccessCallback adds a callback.
300+
func SetAccessCallback(callback func(knox.AccessCallbackInput) (bool, error)) {
301+
accessCallback = callback
302+
}
303+
297304
// Extra validators to apply on principals submitted to Knox.
298305
var extraPrincipalValidators []knox.PrincipalValidator
299306

server/api_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ func additionalMockHandler(m KeyManager, principal knox.Principal, parameters ma
4949
return "The meaning of life is 42", nil
5050
}
5151

52+
func mockAccessCallback(input knox.AccessCallbackInput) (bool, error) {
53+
return true, nil
54+
}
55+
5256
func mockRoute() Route {
5357
return Route{
5458
Method: "GET",
@@ -105,6 +109,27 @@ func TestAddDefaultAccess(t *testing.T) {
105109

106110
}
107111

112+
func TestSetAccessCallback(t *testing.T) {
113+
defer SetAccessCallback(nil)
114+
115+
SetAccessCallback(mockAccessCallback)
116+
117+
input := knox.AccessCallbackInput{}
118+
119+
if accessCallback == nil {
120+
t.Fatal("accessCallback should not be nil")
121+
}
122+
123+
canAccess, err := accessCallback(input)
124+
if err != nil {
125+
t.Fatal("accessCallback should not return an error")
126+
}
127+
128+
if !canAccess {
129+
t.Fatal("accessCallback should return true")
130+
}
131+
}
132+
108133
func TestParseFormParameter(t *testing.T) {
109134
p := PostParameter("key")
110135

server/auth/auth.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,15 @@ func (u user) GetID() string {
380380
return u.ID
381381
}
382382

383+
func (u user) Raw() []knox.RawPrincipal {
384+
return []knox.RawPrincipal{
385+
{
386+
ID: u.GetID(),
387+
Type: u.Type(),
388+
},
389+
}
390+
}
391+
383392
// Type returns the underlying type of a principal, for logging/debugging purposes.
384393
func (u user) Type() string {
385394
return "user"
@@ -415,6 +424,15 @@ func (m machine) Type() string {
415424
return "machine"
416425
}
417426

427+
func (m machine) Raw() []knox.RawPrincipal {
428+
return []knox.RawPrincipal{
429+
{
430+
ID: m.GetID(),
431+
Type: m.Type(),
432+
},
433+
}
434+
}
435+
418436
// CanAccess determines if a Machine can access an object represented by the ACL
419437
// with a certain AccessType. It compares Machine hostname and hostname prefix.
420438
func (m machine) CanAccess(acl knox.ACL, t knox.AccessType) bool {
@@ -450,6 +468,15 @@ func (s service) Type() string {
450468
return "service"
451469
}
452470

471+
func (s service) Raw() []knox.RawPrincipal {
472+
return []knox.RawPrincipal{
473+
{
474+
ID: s.GetID(),
475+
Type: s.Type(),
476+
},
477+
}
478+
}
479+
453480
// CanAccess determines if a Service can access an object represented by the ACL
454481
// with a certain AccessType. It compares Service id and id prefix.
455482
func (s service) CanAccess(acl knox.ACL, t knox.AccessType) bool {

server/routes.go

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"strconv"
99

1010
"github.com/pinterest/knox"
11+
"github.com/pinterest/knox/log"
1112
"github.com/pinterest/knox/server/auth"
1213
)
1314

@@ -211,9 +212,15 @@ func getKeyHandler(m KeyManager, principal knox.Principal, parameters map[string
211212
}
212213

213214
// Authorize access to data
214-
if !principal.CanAccess(key.ACL, knox.Read) {
215+
authorized, authzErr := authorizeRequest(key, principal, knox.Read)
216+
if authzErr != nil {
217+
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
218+
}
219+
220+
if !authorized {
215221
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to read %s", principal.GetID(), keyID))
216222
}
223+
217224
// Zero ACL for key response, in order to avoid caching unnecessarily
218225
key.ACL = knox.ACL{}
219226
return key, nil
@@ -234,7 +241,12 @@ func deleteKeyHandler(m KeyManager, principal knox.Principal, parameters map[str
234241
}
235242

236243
// Authorize
237-
if !principal.CanAccess(key.ACL, knox.Admin) {
244+
authorized, authzErr := authorizeRequest(key, principal, knox.Admin)
245+
if authzErr != nil {
246+
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
247+
}
248+
249+
if !authorized {
238250
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to delete %s", principal.GetID(), keyID))
239251
}
240252

@@ -314,7 +326,12 @@ func putAccessHandler(m KeyManager, principal knox.Principal, parameters map[str
314326
}
315327

316328
// Authorize
317-
if !principal.CanAccess(key.ACL, knox.Admin) {
329+
authorized, authzErr := authorizeRequest(key, principal, knox.Admin)
330+
if authzErr != nil {
331+
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
332+
}
333+
334+
if !authorized {
318335
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to update access for %s", principal.GetID(), keyID))
319336
}
320337

@@ -371,7 +388,12 @@ func postVersionHandler(m KeyManager, principal knox.Principal, parameters map[s
371388
}
372389

373390
// Authorize
374-
if !principal.CanAccess(key.ACL, knox.Write) {
391+
authorized, authzErr := authorizeRequest(key, principal, knox.Write)
392+
if authzErr != nil {
393+
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
394+
}
395+
396+
if !authorized {
375397
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to write %s", principal.GetID(), keyID))
376398
}
377399

@@ -428,7 +450,12 @@ func putVersionsHandler(m KeyManager, principal knox.Principal, parameters map[s
428450
}
429451

430452
// Authorize
431-
if !principal.CanAccess(key.ACL, knox.Write) {
453+
authorized, authzErr := authorizeRequest(key, principal, knox.Write)
454+
if authzErr != nil {
455+
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
456+
}
457+
458+
if !authorized {
432459
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to write %s", principal.GetID(), keyID))
433460
}
434461

@@ -445,3 +472,25 @@ func putVersionsHandler(m KeyManager, principal knox.Principal, parameters map[s
445472
return nil, errF(knox.InternalServerErrorCode, err.Error())
446473
}
447474
}
475+
476+
func authorizeRequest(key *knox.Key, principal knox.Principal, access knox.AccessType) (allow bool, err error) {
477+
defer func() {
478+
if r := recover(); r != nil {
479+
log.Printf("Recovered from panic in access callback: %v", r)
480+
481+
err = fmt.Errorf("Recovered from panic in access callback: %v", r)
482+
}
483+
}()
484+
485+
allow = principal.CanAccess(key.ACL, access)
486+
487+
if !allow && accessCallback != nil {
488+
allow, err = accessCallback(knox.AccessCallbackInput{
489+
Key: *key,
490+
Principals: principal.Raw(),
491+
AccessType: access,
492+
})
493+
}
494+
495+
return
496+
}

0 commit comments

Comments
 (0)