Skip to content

Commit afd1edf

Browse files
authored
feat: check token before fetching docs if protected (#218)
* feat: check token before fetching docs if protected On-behalf-of: @SAP [email protected] Signed-off-by: Artem Shcherbatiuk <[email protected]> --------- Signed-off-by: Artem Shcherbatiuk <[email protected]>
1 parent ad7c35a commit afd1edf

File tree

9 files changed

+572
-270
lines changed

9 files changed

+572
-270
lines changed

cmd/root.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,14 @@ func initConfig() {
5151
v.SetDefault("openapi-definitions-path", "./bin/definitions")
5252
v.SetDefault("enable-kcp", true)
5353
v.SetDefault("local-development", false)
54+
v.SetDefault("authenticate-schema-requests", false)
5455

5556
// Listener
5657
v.SetDefault("listener-apiexport-workspace", ":root")
5758
v.SetDefault("listener-apiexport-name", "kcp.io")
5859

5960
// Gateway
60-
v.SetDefault("gateway-port", "9080")
61+
v.SetDefault("gateway-port", "8080")
6162
v.SetDefault("gateway-username-claim", "email")
6263
v.SetDefault("gateway-should-impersonate", true)
6364
// Gateway Handler config

common/config/config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package config
22

33
type Config struct {
4-
OpenApiDefinitionsPath string `mapstructure:"openapi-definitions-path"`
5-
EnableKcp bool `mapstructure:"enable-kcp"`
6-
LocalDevelopment bool `mapstructure:"local-development"`
4+
OpenApiDefinitionsPath string `mapstructure:"openapi-definitions-path"`
5+
EnableKcp bool `mapstructure:"enable-kcp"`
6+
LocalDevelopment bool `mapstructure:"local-development"`
7+
AuthenticateSchemaRequests bool `mapstructure:"authenticate-schema-requests"`
78

89
Listener struct {
910
// Listener fields will be added here

gateway/manager/export_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package manager
2+
3+
import (
4+
"github.com/openmfp/golang-commons/logger/testlogger"
5+
appConfig "github.com/openmfp/kubernetes-graphql-gateway/common/config"
6+
)
7+
8+
func NewManagerForTest() *Service {
9+
cfg := appConfig.Config{}
10+
cfg.Gateway.Cors.Enabled = true
11+
cfg.Gateway.Cors.AllowedOrigins = []string{"*"}
12+
cfg.Gateway.Cors.AllowedHeaders = []string{"Authorization"}
13+
14+
s := &Service{
15+
AppCfg: cfg,
16+
handlers: handlerStore{registry: make(map[string]*graphqlHandler)},
17+
log: testlogger.New().HideLogOutput().Logger,
18+
resolver: nil,
19+
}
20+
s.handlers.registry["testws"] = &graphqlHandler{}
21+
22+
return s
23+
}

gateway/manager/handler.go

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
package manager
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"strings"
12+
"sync"
13+
14+
"github.com/graphql-go/graphql"
15+
"github.com/graphql-go/handler"
16+
"github.com/kcp-dev/logicalcluster/v3"
17+
"k8s.io/client-go/rest"
18+
"sigs.k8s.io/controller-runtime/pkg/kontext"
19+
20+
"github.com/openmfp/golang-commons/sentry"
21+
)
22+
23+
var (
24+
ErrNoHandlerFound = errors.New("no handler found for workspace")
25+
)
26+
27+
type handlerStore struct {
28+
mu sync.RWMutex
29+
registry map[string]*graphqlHandler
30+
}
31+
32+
type graphqlHandler struct {
33+
schema *graphql.Schema
34+
handler http.Handler
35+
}
36+
37+
func (s *Service) createHandler(schema *graphql.Schema) *graphqlHandler {
38+
h := handler.New(&handler.Config{
39+
Schema: schema,
40+
Pretty: s.AppCfg.Gateway.HandlerCfg.Pretty,
41+
Playground: s.AppCfg.Gateway.HandlerCfg.Playground,
42+
GraphiQL: s.AppCfg.Gateway.HandlerCfg.GraphiQL,
43+
})
44+
return &graphqlHandler{
45+
schema: schema,
46+
handler: h,
47+
}
48+
}
49+
50+
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
51+
if s.handleCORS(w, r) {
52+
return
53+
}
54+
55+
workspace, h, ok := s.getWorkspaceAndHandler(w, r)
56+
if !ok {
57+
return
58+
}
59+
60+
if r.Method == http.MethodGet {
61+
h.handler.ServeHTTP(w, r)
62+
return
63+
}
64+
65+
token := getToken(r)
66+
67+
if !s.handleAuth(w, r, token) {
68+
return
69+
}
70+
71+
s.setContexts(r, workspace, token)
72+
73+
if r.Header.Get("Accept") == "text/event-stream" {
74+
s.handleSubscription(w, r, h.schema)
75+
} else {
76+
h.handler.ServeHTTP(w, r)
77+
}
78+
}
79+
80+
func (s *Service) handleCORS(w http.ResponseWriter, r *http.Request) bool {
81+
if s.AppCfg.Gateway.Cors.Enabled {
82+
allowedOrigins := strings.Join(s.AppCfg.Gateway.Cors.AllowedOrigins, ",")
83+
allowedHeaders := strings.Join(s.AppCfg.Gateway.Cors.AllowedHeaders, ",")
84+
w.Header().Set("Access-Control-Allow-Origin", allowedOrigins)
85+
w.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
86+
// setting cors allowed methods is not needed for this service,
87+
// as all graphql methods are part of the cors safelisted methods
88+
// https://fetch.spec.whatwg.org/#cors-safelisted-method
89+
90+
if r.Method == http.MethodOptions {
91+
w.WriteHeader(http.StatusOK)
92+
return true
93+
}
94+
}
95+
return false
96+
}
97+
98+
// getWorkspaceAndHandler extracts the workspace from the path, finds the handler, and handles errors.
99+
// Returns workspace, handler, and ok (true if found, false if error was handled).
100+
func (s *Service) getWorkspaceAndHandler(w http.ResponseWriter, r *http.Request) (string, *graphqlHandler, bool) {
101+
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
102+
if len(parts) != 2 {
103+
s.log.Error().Err(fmt.Errorf("invalid path")).Str("path", r.URL.Path).Msg("Error parsing path")
104+
http.NotFound(w, r)
105+
return "", nil, false
106+
}
107+
108+
workspace := parts[0]
109+
110+
s.handlers.mu.RLock()
111+
h, ok := s.handlers.registry[workspace]
112+
s.handlers.mu.RUnlock()
113+
114+
if !ok {
115+
s.log.Error().Err(ErrNoHandlerFound).Str("workspace", workspace)
116+
sentry.CaptureError(ErrNoHandlerFound, sentry.Tags{"workspace": workspace})
117+
http.NotFound(w, r)
118+
return "", nil, false
119+
}
120+
121+
return workspace, h, true
122+
}
123+
124+
func getToken(r *http.Request) string {
125+
token := r.Header.Get("Authorization")
126+
token = strings.TrimPrefix(token, "Bearer ")
127+
token = strings.TrimPrefix(token, "bearer ")
128+
129+
return token
130+
}
131+
132+
func (s *Service) handleAuth(w http.ResponseWriter, r *http.Request, token string) bool {
133+
if !s.AppCfg.LocalDevelopment {
134+
if token == "" {
135+
http.Error(w, "Authorization header is required", http.StatusUnauthorized)
136+
return false
137+
}
138+
139+
if s.AppCfg.AuthenticateSchemaRequests {
140+
if s.isIntrospectionQuery(r) {
141+
ok, err := s.validateToken(r.Context(), token)
142+
if err != nil {
143+
s.log.Error().Err(err).Msg("error validating token with k8s")
144+
http.Error(w, "error validating token", http.StatusInternalServerError)
145+
return false
146+
}
147+
148+
if !ok {
149+
http.Error(w, "Provided token is not authorized to access the cluster", http.StatusUnauthorized)
150+
return false
151+
}
152+
}
153+
}
154+
}
155+
return true
156+
}
157+
158+
func (s *Service) isIntrospectionQuery(r *http.Request) bool {
159+
var params struct {
160+
Query string `json:"query"`
161+
}
162+
bodyBytes, err := io.ReadAll(r.Body)
163+
r.Body.Close()
164+
if err == nil {
165+
if err = json.Unmarshal(bodyBytes, &params); err == nil {
166+
if strings.Contains(params.Query, "__schema") || strings.Contains(params.Query, "__type") {
167+
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
168+
return true
169+
}
170+
}
171+
}
172+
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
173+
return false
174+
}
175+
176+
// validateToken uses the /version endpoint for a general authentication check.
177+
func (s *Service) validateToken(ctx context.Context, token string) (bool, error) {
178+
cfg := &rest.Config{
179+
Host: s.restCfg.Host,
180+
TLSClientConfig: rest.TLSClientConfig{
181+
CAFile: s.restCfg.TLSClientConfig.CAFile,
182+
CAData: s.restCfg.TLSClientConfig.CAData,
183+
},
184+
BearerToken: token,
185+
}
186+
187+
httpClient, err := rest.HTTPClientFor(cfg)
188+
if err != nil {
189+
return false, err
190+
}
191+
192+
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/version", cfg.Host), nil)
193+
if err != nil {
194+
return false, err
195+
}
196+
197+
resp, err := httpClient.Do(req)
198+
if err != nil {
199+
return false, err
200+
}
201+
resp.Body.Close()
202+
203+
switch resp.StatusCode {
204+
case http.StatusUnauthorized:
205+
return false, nil
206+
case http.StatusOK:
207+
return true, nil
208+
default:
209+
return false, fmt.Errorf("unexpected status code from /version: %d", resp.StatusCode)
210+
}
211+
}
212+
213+
func (s *Service) setContexts(r *http.Request, workspace, token string) *http.Request {
214+
if s.AppCfg.EnableKcp {
215+
r = r.WithContext(kontext.WithCluster(r.Context(), logicalcluster.Name(workspace)))
216+
}
217+
return r.WithContext(context.WithValue(r.Context(), TokenKey{}, token))
218+
}
219+
220+
func (s *Service) handleSubscription(w http.ResponseWriter, r *http.Request, schema *graphql.Schema) {
221+
// Set SSE headers
222+
w.Header().Set("Content-Type", "text/event-stream")
223+
w.Header().Set("Cache-Control", "no-cache")
224+
w.Header().Set("Connection", "keep-alive")
225+
226+
var params struct {
227+
Query string `json:"query"`
228+
OperationName string `json:"operationName"`
229+
Variables map[string]interface{} `json:"variables"`
230+
}
231+
232+
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
233+
http.Error(w, "Error parsing JSON request body", http.StatusBadRequest)
234+
return
235+
}
236+
237+
flusher := http.NewResponseController(w)
238+
r.Body.Close()
239+
240+
subscriptionParams := graphql.Params{
241+
Schema: *schema,
242+
RequestString: params.Query,
243+
VariableValues: params.Variables,
244+
OperationName: params.OperationName,
245+
Context: r.Context(),
246+
}
247+
248+
subscriptionChannel := graphql.Subscribe(subscriptionParams)
249+
for res := range subscriptionChannel {
250+
if res == nil {
251+
continue
252+
}
253+
254+
data, err := json.Marshal(res)
255+
if err != nil {
256+
s.log.Error().Err(err).Msg("Error marshalling subscription response")
257+
continue
258+
}
259+
260+
fmt.Fprintf(w, "event: next\ndata: %s\n\n", data)
261+
flusher.Flush()
262+
}
263+
264+
fmt.Fprint(w, "event: complete\n\n")
265+
}

gateway/manager/handler_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package manager_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/openmfp/kubernetes-graphql-gateway/gateway/manager"
9+
)
10+
11+
func TestServeHTTP_CORSPreflight(t *testing.T) {
12+
s := manager.NewManagerForTest()
13+
req := httptest.NewRequest(http.MethodOptions, "/testws/graphql", nil)
14+
w := httptest.NewRecorder()
15+
s.ServeHTTP(w, req)
16+
if w.Code != http.StatusOK {
17+
t.Errorf("expected 200 for CORS preflight, got %d", w.Code)
18+
}
19+
if w.Header().Get("Access-Control-Allow-Origin") == "" {
20+
t.Error("CORS headers not set")
21+
}
22+
}
23+
24+
func TestServeHTTP_InvalidWorkspace(t *testing.T) {
25+
s := manager.NewManagerForTest()
26+
req := httptest.NewRequest(http.MethodGet, "/invalidws/graphql", nil)
27+
w := httptest.NewRecorder()
28+
s.ServeHTTP(w, req)
29+
if w.Code != http.StatusNotFound {
30+
t.Errorf("expected 404 for invalid workspace, got %d", w.Code)
31+
}
32+
}
33+
34+
func TestServeHTTP_AuthRequired_NoToken(t *testing.T) {
35+
s := manager.NewManagerForTest()
36+
s.AppCfg.LocalDevelopment = false
37+
req := httptest.NewRequest(http.MethodPost, "/testws/graphql", nil)
38+
w := httptest.NewRecorder()
39+
s.ServeHTTP(w, req)
40+
if w.Code != http.StatusUnauthorized {
41+
t.Errorf("expected 401 for missing token, got %d", w.Code)
42+
}
43+
}

0 commit comments

Comments
 (0)