Skip to content

Commit b2b599f

Browse files
authored
Implement authorization for Raw InferenceGraphs (#499)
* Implement authorization for Raw InferenceGraphs Authorization is implemented by using the TokenReview and the SubjectAccessReview Kubernetes APIs. A Middleware function is set up when some arguments are specified that trigger plugging in the middleware func. Some additional reconciliation is added to the InferenceGraph controller to: * Switch to a different ServiceAccount so that privileges for using the cluster APIs are granted. * Create the needed ServiceAccount for the auth-protected InferenceGraph to run. * Manage a ClusterRoleBinding to give the required privileges for auth verification. Signed-off-by: Edgar Hernández <[email protected]> * Feedback: Jooho Fix comment * Fix unit test Signed-off-by: Edgar Hernández <[email protected]> --------- Signed-off-by: Edgar Hernández <[email protected]>
1 parent 8837d5f commit b2b599f

File tree

9 files changed

+606
-22
lines changed

9 files changed

+606
-22
lines changed

charts/kserve-resources/templates/clusterrole.yaml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,17 @@ rules:
5151
- ""
5252
resources:
5353
- secrets
54+
verbs:
55+
- get
56+
- apiGroups:
57+
- ""
58+
resources:
5459
- serviceaccounts
5560
verbs:
61+
- create
62+
- delete
5663
- get
64+
- patch
5765
- apiGroups:
5866
- admissionregistration.k8s.io
5967
resources:
@@ -124,6 +132,34 @@ rules:
124132
- patch
125133
- update
126134
- watch
135+
- apiGroups:
136+
- rbac.authorization.k8s.io
137+
resourceNames:
138+
- kserve-inferencegraph-auth-verifiers
139+
resources:
140+
- clusterrolebindings
141+
verbs:
142+
- create
143+
- get
144+
- patch
145+
- update
146+
- apiGroups:
147+
- route.openshift.io
148+
resources:
149+
- routes
150+
verbs:
151+
- create
152+
- get
153+
- list
154+
- patch
155+
- update
156+
- watch
157+
- apiGroups:
158+
- route.openshift.io
159+
resources:
160+
- routes/status
161+
verbs:
162+
- get
127163
- apiGroups:
128164
- serving.knative.dev
129165
resources:

cmd/router/main.go

Lines changed: 173 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ package main
1818

1919
import (
2020
"bytes"
21+
"context"
22+
"crypto/rand"
2123
"encoding/json"
2224
goerrors "errors"
2325
"fmt"
2426
"io"
27+
"math/big"
2528
"net/http"
2629
"net/url"
2730
"os"
@@ -31,18 +34,19 @@ import (
3134
"syscall"
3235
"time"
3336

34-
"github.com/kserve/kserve/pkg/constants"
3537
"github.com/pkg/errors"
36-
38+
flag "github.com/spf13/pflag"
3739
"github.com/tidwall/gjson"
40+
authnv1 "k8s.io/api/authentication/v1"
41+
authzv1 "k8s.io/api/authorization/v1"
42+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
43+
"k8s.io/client-go/kubernetes"
44+
"k8s.io/client-go/rest"
3845
logf "sigs.k8s.io/controller-runtime/pkg/log"
3946
"sigs.k8s.io/controller-runtime/pkg/log/zap"
4047

41-
"crypto/rand"
42-
"math/big"
43-
4448
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
45-
flag "github.com/spf13/pflag"
49+
"github.com/kserve/kserve/pkg/constants"
4650
)
4751

4852
var log = logf.Log.WithName("InferenceGraphRouter")
@@ -411,10 +415,158 @@ func compilePatterns(patterns []string) ([]*regexp.Regexp, error) {
411415
}
412416

413417
// Command-line flags and package-level state of the router process.
var (
418+
// enableAuthFlag turns on the authorization middleware: when set, main wraps
// the graph handler with authMiddleware before serving.
enableAuthFlag = flag.Bool("enable-auth", false, "protect the inference graph with authorization")
419+
// graphName is the name of the InferenceGraph resource this router serves. It is
// used as the resource name in the SubjectAccessReview performed by
// checkRequestIsAuthorized, which rejects requests with a 500 when it is empty.
graphName = flag.String("inferencegraph-name", "", "the name of the associated inference graph Kubernetes resource")
414420
// jsonGraph holds the serialized JSON definition of the inference graph.
jsonGraph = flag.String("graph-json", "", "serialized json graph def")
415421
// compiledHeaderPatterns is a list of compiled regular expressions.
// NOTE(review): presumably filled from compilePatterns with the header
// propagation patterns; the assignment is not visible here, confirm.
compiledHeaderPatterns []*regexp.Regexp
416422
)
417423

424+
// findBearerToken parses the standard HTTP Authorization header to find and return
// a Bearer token that a client may have provided in the request.
//
// If a non-empty token is found, it is returned and nothing is written to w.
// Otherwise an empty string is returned and the request has already been rejected:
// the response carries http.StatusUnauthorized and an X-Forbidden-Reason header
// describing why. Callers must not write to w after receiving an empty string.
//
// The "Bearer" scheme is matched case-insensitively, as HTTP authentication
// schemes are case-insensitive.
func findBearerToken(w http.ResponseWriter, r *http.Request) string {
	// reject answers the request with 401 and the given reason.
	reject := func(reason string) string {
		w.Header().Set("X-Forbidden-Reason", reason)
		w.WriteHeader(http.StatusUnauthorized)
		return ""
	}

	// Look for the HTTP Authorization header. Reject the request if not available.
	authHeader := r.Header.Get("Authorization")
	if len(authHeader) == 0 {
		return reject("No credentials were provided")
	}

	// Parse the header: it must start with the "Bearer " scheme prefix.
	const scheme = "Bearer "
	if len(authHeader) < len(scheme) || !strings.EqualFold(authHeader[:len(scheme)], scheme) {
		return reject("Only Bearer tokens are supported")
	}

	// A scheme without a token must be rejected explicitly. Returning "" without
	// writing a status would leave the response as an implicit 200 OK.
	token := strings.TrimSpace(authHeader[len(scheme):])
	if len(token) == 0 {
		return reject("No credentials were provided")
	}
	return token
}
447+
448+
// validateTokenIsAuthenticated queries the Kubernetes cluster to find if the provided token is
449+
// valid and flagged as authenticated. If the token is usable, the result of the TokenReview
450+
// is returned. Otherwise, the HTTP response is sent rejecting the request and setting
451+
// a meaningful status code along with a reason (if available).
452+
func validateTokenIsAuthenticated(w http.ResponseWriter, token string, clientset *kubernetes.Clientset) *authnv1.TokenReview {
453+
// Check the token is valid
454+
tokenReview := authnv1.TokenReview{}
455+
tokenReview.Spec.Token = token
456+
tokenReviewResult, err := clientset.AuthenticationV1().TokenReviews().Create(context.Background(), &tokenReview, metav1.CreateOptions{})
457+
if err != nil {
458+
log.Error(err, "failed to create TokenReview when verifying credentials")
459+
w.WriteHeader(http.StatusInternalServerError)
460+
return nil
461+
}
462+
if len(tokenReviewResult.Status.Error) != 0 {
463+
w.Header().Set("X-Forbidden-Reason", tokenReviewResult.Status.Error)
464+
w.WriteHeader(http.StatusUnauthorized)
465+
return nil
466+
}
467+
if !tokenReviewResult.Status.Authenticated {
468+
w.Header().Set("X-Forbidden-Reason", "The provided token is unauthenticated")
469+
w.WriteHeader(http.StatusUnauthorized)
470+
return nil
471+
}
472+
return tokenReviewResult
473+
}
474+
475+
// checkRequestIsAuthorized verifies that the user in the provided tokenReviewResult has privileges to query the
476+
// Kubernetes API and get the InferenceGraph resource that belongs to this pod. If so, the request is considered
477+
// as allowed and `true` is returned. Otherwise, the HTTP response is sent rejecting the request and setting
478+
// a meaningful status code along with a reason (if available).
479+
func checkRequestIsAuthorized(w http.ResponseWriter, _ *http.Request, tokenReviewResult *authnv1.TokenReview, clientset *kubernetes.Clientset) bool {
480+
// Read pod namespace
481+
const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
482+
namespaceBytes, err := os.ReadFile(namespaceFile)
483+
if err != nil {
484+
log.Error(err, "failed to read namespace file while verifying credentials")
485+
w.WriteHeader(http.StatusInternalServerError)
486+
return false
487+
}
488+
namespace := string(namespaceBytes)
489+
490+
// Check the subject is authorized to query the InferenceGraph
491+
if len(*graphName) == 0 {
492+
log.Error(errors.New("no graph name provided"), "the --inferencegraph-name flag wasn't provided")
493+
w.WriteHeader(http.StatusInternalServerError)
494+
return false
495+
}
496+
accessReview := authzv1.SubjectAccessReview{
497+
Spec: authzv1.SubjectAccessReviewSpec{
498+
ResourceAttributes: &authzv1.ResourceAttributes{
499+
Namespace: namespace,
500+
Verb: "get",
501+
Group: "serving.kserve.io",
502+
Resource: "inferencegraphs",
503+
Name: *graphName,
504+
},
505+
User: tokenReviewResult.Status.User.Username,
506+
Groups: nil,
507+
},
508+
}
509+
510+
accessReviewResult, err := clientset.AuthorizationV1().SubjectAccessReviews().Create(context.Background(), &accessReview, metav1.CreateOptions{})
511+
if err != nil {
512+
log.Error(err, "failed to create LocalSubjectAccessReview when verifying credentials")
513+
w.WriteHeader(http.StatusInternalServerError)
514+
return false
515+
}
516+
if accessReviewResult.Status.Allowed {
517+
// Note: This is here so that the request is NOT allowed by default.
518+
return true
519+
}
520+
521+
w.Header().Add("X-Forbidden-Reason", "Access to the InferenceGraph is not allowed")
522+
if len(accessReviewResult.Status.Reason) != 0 {
523+
w.Header().Add("X-Forbidden-Reason", accessReviewResult.Status.Reason)
524+
}
525+
if len(accessReviewResult.Status.EvaluationError) != 0 {
526+
w.Header().Add("X-Forbidden-Reason", accessReviewResult.Status.EvaluationError)
527+
}
528+
529+
w.WriteHeader(http.StatusUnauthorized)
530+
return false
531+
}
532+
533+
// authMiddleware uses the Middleware pattern to protect the InferenceGraph behind authorization.
534+
// It expects that a Bearer token is provided in the request in the standard HTTP Authorization
535+
// header. The token is verified against Kubernetes using the TokenReview and SubjectAccessReview APIs.
536+
// If the token is valid and has enough privileges, the handler provided in the `next` argument is run.
537+
// Otherwise, `next` is not invoked and the reason for the rejection is sent in response headers.
538+
func authMiddleware(next http.Handler) (http.Handler, error) {
539+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
540+
k8sConfig, k8sConfigErr := rest.InClusterConfig()
541+
if k8sConfigErr != nil {
542+
log.Error(k8sConfigErr, "failed to create rest configuration to connect to Kubernetes API")
543+
w.WriteHeader(http.StatusInternalServerError)
544+
return
545+
}
546+
547+
clientset, clientsetErr := kubernetes.NewForConfig(k8sConfig)
548+
if clientsetErr != nil {
549+
log.Error(k8sConfigErr, "failed to create Kubernetes client to connect to API")
550+
return
551+
}
552+
553+
token := findBearerToken(w, r)
554+
if len(token) == 0 {
555+
return
556+
}
557+
558+
tokenReviewResult := validateTokenIsAuthenticated(w, token, clientset)
559+
if tokenReviewResult == nil {
560+
return
561+
}
562+
563+
isAuthorized := checkRequestIsAuthorized(w, r, tokenReviewResult, clientset)
564+
if isAuthorized {
565+
next.ServeHTTP(w, r)
566+
}
567+
}), nil
568+
}
569+
418570
func main() {
419571
flag.Parse()
420572
logf.SetLogger(zap.New())
@@ -434,14 +586,23 @@ func main() {
434586
os.Exit(1)
435587
}
436588

437-
http.HandleFunc("/", graphHandler)
589+
var entrypointHandler http.Handler
590+
entrypointHandler = http.HandlerFunc(graphHandler)
591+
if *enableAuthFlag {
592+
entrypointHandler, err = authMiddleware(entrypointHandler)
593+
log.Info("This Router has authorization enabled")
594+
if err != nil {
595+
log.Error(err, "failed to create entrypoint handler")
596+
os.Exit(1)
597+
}
598+
}
438599

439600
server := &http.Server{
440-
Addr: ":8080", // specify the address and port
441-
Handler: http.HandlerFunc(graphHandler), // specify your HTTP handler
442-
ReadTimeout: time.Minute, // set the maximum duration for reading the entire request, including the body
443-
WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response
444-
IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled
601+
Addr: ":8080", // specify the address and port
602+
Handler: entrypointHandler, // specify your HTTP handler
603+
ReadTimeout: time.Minute, // set the maximum duration for reading the entire request, including the body
604+
WriteTimeout: time.Minute, // set the maximum duration before timing out writes of the response
605+
IdleTimeout: 3 * time.Minute, // set the maximum amount of time to wait for the next request when keep-alives are enabled
445606
}
446607
err = server.ListenAndServe()
447608

config/rbac/role.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,17 @@ rules:
3838
- ""
3939
resources:
4040
- secrets
41+
verbs:
42+
- get
43+
- apiGroups:
44+
- ""
45+
resources:
4146
- serviceaccounts
4247
verbs:
48+
- create
49+
- delete
4350
- get
51+
- patch
4452
- apiGroups:
4553
- admissionregistration.k8s.io
4654
resources:
@@ -111,6 +119,17 @@ rules:
111119
- patch
112120
- update
113121
- watch
122+
- apiGroups:
123+
- rbac.authorization.k8s.io
124+
resourceNames:
125+
- kserve-inferencegraph-auth-verifiers
126+
resources:
127+
- clusterrolebindings
128+
verbs:
129+
- create
130+
- get
131+
- patch
132+
- update
114133
- apiGroups:
115134
- route.openshift.io
116135
resources:

pkg/constants/constants.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ var (
5353
// InferenceGraph and router related constants.
const (
5454
// RouterHeadersPropagateEnvVar is the string "PROPAGATE_HEADERS".
// NOTE(review): presumably the name of the environment variable listing the
// headers the router propagates; confirm against the router code.
RouterHeadersPropagateEnvVar = "PROPAGATE_HEADERS"
5555
// InferenceGraphLabel is the "serving.kserve.io/inferencegraph" label key.
InferenceGraphLabel = "serving.kserve.io/inferencegraph"
56+
// InferenceGraphAuthCRBName is the name of the ClusterRoleBinding managed for
// auth-protected InferenceGraphs. It matches the resourceNames entry that the
// controller's RBAC rules grant for clusterrolebindings.
InferenceGraphAuthCRBName = "kserve-inferencegraph-auth-verifiers"
57+
// InferenceGraphFinalizerName is the finalizer string for InferenceGraphs.
// NOTE(review): presumably set by the InferenceGraph controller so it can clean
// up the auth resources on deletion; the usage is not visible here, confirm.
InferenceGraphFinalizerName = "inferencegraph.finalizers"
5658
)
5759

5860
// TrainedModel Constants

0 commit comments

Comments
 (0)