@@ -18,10 +18,13 @@ package main
18
18
19
19
import (
20
20
"bytes"
21
+ "context"
22
+ "crypto/rand"
21
23
"encoding/json"
22
24
goerrors "errors"
23
25
"fmt"
24
26
"io"
27
+ "math/big"
25
28
"net/http"
26
29
"net/url"
27
30
"os"
@@ -31,18 +34,19 @@ import (
31
34
"syscall"
32
35
"time"
33
36
34
- "github.com/kserve/kserve/pkg/constants"
35
37
"github.com/pkg/errors"
36
-
38
+ flag "github.com/spf13/pflag"
37
39
"github.com/tidwall/gjson"
40
+ authnv1 "k8s.io/api/authentication/v1"
41
+ authzv1 "k8s.io/api/authorization/v1"
42
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
43
+ "k8s.io/client-go/kubernetes"
44
+ "k8s.io/client-go/rest"
38
45
logf "sigs.k8s.io/controller-runtime/pkg/log"
39
46
"sigs.k8s.io/controller-runtime/pkg/log/zap"
40
47
41
- "crypto/rand"
42
- "math/big"
43
-
44
48
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
45
- flag "github.com/spf13/pflag "
49
+ "github.com/kserve/kserve/pkg/constants "
46
50
)
47
51
48
52
var log = logf .Log .WithName ("InferenceGraphRouter" )
@@ -411,10 +415,158 @@ func compilePatterns(patterns []string) ([]*regexp.Regexp, error) {
411
415
}
412
416
413
417
var (
418
+ enableAuthFlag = flag .Bool ("enable-auth" , false , "protect the inference graph with authorization" )
419
+ graphName = flag .String ("inferencegraph-name" , "" , "the name of the associated inference graph Kubernetes resource" )
414
420
jsonGraph = flag .String ("graph-json" , "" , "serialized json graph def" )
415
421
compiledHeaderPatterns []* regexp.Regexp
416
422
)
417
423
424
// findBearerToken parses the standard HTTP Authorization header and returns the
// Bearer token supplied by the client.
//
// On success the token is returned and nothing is written to w. In every other
// case an empty string is returned and the request is rejected with
// 401 Unauthorized, with the reason placed in the X-Forbidden-Reason response
// header:
//   - the Authorization header is missing,
//   - the header does not use the Bearer scheme,
//   - the Bearer scheme is present but carries no credentials.
//
// The scheme name is matched case-insensitively, as HTTP authentication
// schemes are case-insensitive.
func findBearerToken(w http.ResponseWriter, r *http.Request) string {
	const bearerPrefix = "Bearer "

	// reject sends the 401 response with the given reason and yields the empty token.
	reject := func(reason string) string {
		w.Header().Set("X-Forbidden-Reason", reason)
		w.WriteHeader(http.StatusUnauthorized)
		return ""
	}

	authHeader := r.Header.Get("Authorization")
	if len(authHeader) == 0 {
		return reject("No credentials were provided")
	}

	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return reject("Only Bearer tokens are supported")
	}

	// Tolerate extra whitespace around the credentials. An empty token must be
	// rejected here: callers treat "" as "response already sent".
	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if len(token) == 0 {
		return reject("The Bearer token is empty")
	}
	return token
}
447
+
448
+ // validateTokenIsAuthenticated queries the Kubernetes cluster to find if the provided token is
449
+ // valid and flagged as authenticated. If the token is usable, the result of the TokenReview
450
+ // is returned. Otherwise, the HTTP response is sent rejecting the request and setting
451
+ // a meaningful status code along with a reason (if available).
452
+ func validateTokenIsAuthenticated (w http.ResponseWriter , token string , clientset * kubernetes.Clientset ) * authnv1.TokenReview {
453
+ // Check the token is valid
454
+ tokenReview := authnv1.TokenReview {}
455
+ tokenReview .Spec .Token = token
456
+ tokenReviewResult , err := clientset .AuthenticationV1 ().TokenReviews ().Create (context .Background (), & tokenReview , metav1.CreateOptions {})
457
+ if err != nil {
458
+ log .Error (err , "failed to create TokenReview when verifying credentials" )
459
+ w .WriteHeader (http .StatusInternalServerError )
460
+ return nil
461
+ }
462
+ if len (tokenReviewResult .Status .Error ) != 0 {
463
+ w .Header ().Set ("X-Forbidden-Reason" , tokenReviewResult .Status .Error )
464
+ w .WriteHeader (http .StatusUnauthorized )
465
+ return nil
466
+ }
467
+ if ! tokenReviewResult .Status .Authenticated {
468
+ w .Header ().Set ("X-Forbidden-Reason" , "The provided token is unauthenticated" )
469
+ w .WriteHeader (http .StatusUnauthorized )
470
+ return nil
471
+ }
472
+ return tokenReviewResult
473
+ }
474
+
475
+ // checkRequestIsAuthorized verifies that the user in the provided tokenReviewResult has privileges to query the
476
+ // Kubernetes API and get the InferenceGraph resource that belongs to this pod. If so, the request is considered
477
+ // as allowed and `true` is returned. Otherwise, the HTTP response is sent rejecting the request and setting
478
+ // a meaningful status code along with a reason (if available).
479
+ func checkRequestIsAuthorized (w http.ResponseWriter , _ * http.Request , tokenReviewResult * authnv1.TokenReview , clientset * kubernetes.Clientset ) bool {
480
+ // Read pod namespace
481
+ const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
482
+ namespaceBytes , err := os .ReadFile (namespaceFile )
483
+ if err != nil {
484
+ log .Error (err , "failed to read namespace file while verifying credentials" )
485
+ w .WriteHeader (http .StatusInternalServerError )
486
+ return false
487
+ }
488
+ namespace := string (namespaceBytes )
489
+
490
+ // Check the subject is authorized to query the InferenceGraph
491
+ if len (* graphName ) == 0 {
492
+ log .Error (errors .New ("no graph name provided" ), "the --inferencegraph-name flag wasn't provided" )
493
+ w .WriteHeader (http .StatusInternalServerError )
494
+ return false
495
+ }
496
+ accessReview := authzv1.SubjectAccessReview {
497
+ Spec : authzv1.SubjectAccessReviewSpec {
498
+ ResourceAttributes : & authzv1.ResourceAttributes {
499
+ Namespace : namespace ,
500
+ Verb : "get" ,
501
+ Group : "serving.kserve.io" ,
502
+ Resource : "inferencegraphs" ,
503
+ Name : * graphName ,
504
+ },
505
+ User : tokenReviewResult .Status .User .Username ,
506
+ Groups : nil ,
507
+ },
508
+ }
509
+
510
+ accessReviewResult , err := clientset .AuthorizationV1 ().SubjectAccessReviews ().Create (context .Background (), & accessReview , metav1.CreateOptions {})
511
+ if err != nil {
512
+ log .Error (err , "failed to create LocalSubjectAccessReview when verifying credentials" )
513
+ w .WriteHeader (http .StatusInternalServerError )
514
+ return false
515
+ }
516
+ if accessReviewResult .Status .Allowed {
517
+ // Note: This is here so that the request is NOT allowed by default.
518
+ return true
519
+ }
520
+
521
+ w .Header ().Add ("X-Forbidden-Reason" , "Access to the InferenceGraph is not allowed" )
522
+ if len (accessReviewResult .Status .Reason ) != 0 {
523
+ w .Header ().Add ("X-Forbidden-Reason" , accessReviewResult .Status .Reason )
524
+ }
525
+ if len (accessReviewResult .Status .EvaluationError ) != 0 {
526
+ w .Header ().Add ("X-Forbidden-Reason" , accessReviewResult .Status .EvaluationError )
527
+ }
528
+
529
+ w .WriteHeader (http .StatusUnauthorized )
530
+ return false
531
+ }
532
+
533
+ // authMiddleware uses the Middleware pattern to protect the InferenceGraph behind authorization.
534
+ // It expects that a Bearer token is provided in the request in the standard HTTP Authorization
535
+ // header. The token is verified against Kubernetes using the TokenReview and SubjectAccessReview APIs.
536
+ // If the token is valid and has enough privileges, the handler provided in the `next` argument is run.
537
+ // Otherwise, `next` is not invoked and the reason for the rejection is sent in response headers.
538
+ func authMiddleware (next http.Handler ) (http.Handler , error ) {
539
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
540
+ k8sConfig , k8sConfigErr := rest .InClusterConfig ()
541
+ if k8sConfigErr != nil {
542
+ log .Error (k8sConfigErr , "failed to create rest configuration to connect to Kubernetes API" )
543
+ w .WriteHeader (http .StatusInternalServerError )
544
+ return
545
+ }
546
+
547
+ clientset , clientsetErr := kubernetes .NewForConfig (k8sConfig )
548
+ if clientsetErr != nil {
549
+ log .Error (k8sConfigErr , "failed to create Kubernetes client to connect to API" )
550
+ return
551
+ }
552
+
553
+ token := findBearerToken (w , r )
554
+ if len (token ) == 0 {
555
+ return
556
+ }
557
+
558
+ tokenReviewResult := validateTokenIsAuthenticated (w , token , clientset )
559
+ if tokenReviewResult == nil {
560
+ return
561
+ }
562
+
563
+ isAuthorized := checkRequestIsAuthorized (w , r , tokenReviewResult , clientset )
564
+ if isAuthorized {
565
+ next .ServeHTTP (w , r )
566
+ }
567
+ }), nil
568
+ }
569
+
418
570
func main () {
419
571
flag .Parse ()
420
572
logf .SetLogger (zap .New ())
@@ -434,14 +586,23 @@ func main() {
434
586
os .Exit (1 )
435
587
}
436
588
437
- http .HandleFunc ("/" , graphHandler )
589
+ var entrypointHandler http.Handler
590
+ entrypointHandler = http .HandlerFunc (graphHandler )
591
+ if * enableAuthFlag {
592
+ entrypointHandler , err = authMiddleware (entrypointHandler )
593
+ log .Info ("This Router has authorization enabled" )
594
+ if err != nil {
595
+ log .Error (err , "failed to create entrypoint handler" )
596
+ os .Exit (1 )
597
+ }
598
+ }
438
599
439
600
server := & http.Server {
440
- Addr : ":8080" , // specify the address and port
441
- Handler : http . HandlerFunc ( graphHandler ) , // specify your HTTP handler
442
- ReadTimeout : time .Minute , // set the maximum duration for reading the entire request, including the body
443
- WriteTimeout : time .Minute , // set the maximum duration before timing out writes of the response
444
- IdleTimeout : 3 * time .Minute , // set the maximum amount of time to wait for the next request when keep-alives are enabled
601
+ Addr : ":8080" , // specify the address and port
602
+ Handler : entrypointHandler , // specify your HTTP handler
603
+ ReadTimeout : time .Minute , // set the maximum duration for reading the entire request, including the body
604
+ WriteTimeout : time .Minute , // set the maximum duration before timing out writes of the response
605
+ IdleTimeout : 3 * time .Minute , // set the maximum amount of time to wait for the next request when keep-alives are enabled
445
606
}
446
607
err = server .ListenAndServe ()
447
608
0 commit comments