@@ -3,10 +3,13 @@ package cmd
33import (
44 "context"
55 "database/sql"
6+ "encoding/base64"
7+ "encoding/json"
68 "fmt"
79 "log"
810 "net/http"
911 "net/url"
12+ "strings"
1013 "time"
1114
1215 "code.icod.de/dalu/nethttpoidc"
@@ -25,6 +28,9 @@ import (
2528 "github.com/dlukt/graphql-backend-starter/ent"
2629 "github.com/dlukt/graphql-backend-starter/graph"
2730 "github.com/dlukt/graphql-backend-starter/middleware"
31+ "github.com/dlukt/graphql-backend-starter/rules/claims"
32+ "github.com/dlukt/graphql-backend-starter/rules/viewer"
33+ "github.com/gorilla/websocket"
2834 "github.com/rs/cors"
2935 "github.com/spf13/cobra"
3036 "github.com/vektah/gqlparser/v2/ast"
@@ -44,7 +50,7 @@ var graphqlCmd = &cobra.Command{
4450 RunE : func (cmd * cobra.Command , args []string ) error {
4551 fmt .Println ("graphql called" )
4652 setDatabaseURI ()
47-
53+
4854 var client * ent.Client
4955 if useSQLite {
5056 fmt .Println ("Running with SQLite" )
@@ -67,7 +73,7 @@ var graphqlCmd = &cobra.Command{
6773 log .Fatal ("opening ent client" , e )
6874 }
6975
70- srv := NewDefaultServer (graph .NewSchema (client ))
76+ srv := NewDefaultServer (graph .NewSchema (client ), client )
7177 srv .Use (entgql.Transactioner {TxOpener : client })
7278
7379 cfg := config .OidcConfigDev
@@ -153,11 +159,51 @@ func openDB(databaseURL string) *ent.Client {
153159 return ent .NewClient (ent .Driver (driver ))
154160}
155161
156- func NewDefaultServer (es graphql.ExecutableSchema ) * handler.Server {
162+ func NewDefaultServer (es graphql.ExecutableSchema , client * ent. Client ) * handler.Server {
157163 srv := handler .New (es )
158164
159165 srv .AddTransport (transport.Websocket {
160166 KeepAlivePingInterval : 10 * time .Second ,
167+ InitFunc : func (ctx context.Context , p transport.InitPayload ) (context.Context , * transport.InitPayload , error ) {
168+ // Ensure Ent client is present on websocket context as well
169+ ctx = ent .NewContext (ctx , client )
170+ var auth string
171+ if v := p .GetString ("Authorization" ); v != "" {
172+ auth = v
173+ }
174+ if auth == "" {
175+ if h , ok := p ["headers" ].(map [string ]any ); ok {
176+ if s , ok2 := h ["Authorization" ].(string ); ok2 {
177+ auth = s
178+ } else {
179+ for k , val := range h {
180+ if strings .ToLower (k ) == "authorization" {
181+ if s , ok3 := val .(string ); ok3 {
182+ auth = s
183+ }
184+ break
185+ }
186+ }
187+ }
188+ }
189+ }
190+ if auth == "" {
191+ return ctx , nil , nil
192+ }
193+ token := strings .TrimPrefix (auth , "Bearer " )
194+ if m := decodeJWTClaims (token ); m != nil {
195+ ctx = context .WithValue (ctx , options .DefaultClaimsContextKeyName , m )
196+ c := claimsFromMap (m )
197+ v := viewer .NewFromClaims (c )
198+ ctx = viewer .NewContext (ctx , v )
199+ }
200+ return ctx , nil , nil
201+ },
202+ Upgrader : websocket.Upgrader {
203+ CheckOrigin : func (r * http.Request ) bool {
204+ return true
205+ },
206+ },
161207 })
162208 srv .AddTransport (transport.Options {})
163209 srv .AddTransport (transport.GET {})
@@ -173,3 +219,37 @@ func NewDefaultServer(es graphql.ExecutableSchema) *handler.Server {
173219
174220 return srv
175221}
222+
223+ func decodeJWTClaims (token string ) map [string ]any {
224+ parts := strings .Split (token , "." )
225+ if len (parts ) < 2 {
226+ return nil
227+ }
228+ payload := parts [1 ]
229+ // Base64url decode
230+ // add padding if needed
231+ if l := len (payload ) % 4 ; l != 0 {
232+ payload += strings .Repeat ("=" , 4 - l )
233+ }
234+ b , err := base64 .URLEncoding .DecodeString (payload )
235+ if err != nil {
236+ return nil
237+ }
238+ var out map [string ]any
239+ if err := json .Unmarshal (b , & out ); err != nil {
240+ return nil
241+ }
242+ return out
243+ }
244+
245+ func claimsFromMap (m map [string ]any ) * claims.Claims {
246+ j , err := json .Marshal (m )
247+ if err != nil {
248+ return nil
249+ }
250+ var c claims.Claims
251+ if err := json .Unmarshal (j , & c ); err != nil {
252+ return nil
253+ }
254+ return & c
255+ }
0 commit comments