Skip to content

Commit 4df7cbe

Browse files
Implement OAuth (#373)
1 parent c3951f6 commit 4df7cbe

File tree

9 files changed

+808
-7
lines changed

9 files changed

+808
-7
lines changed

go.mod

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/go-redis/redismock/v9 v9.2.0
1111
github.com/go-sql-driver/mysql v1.7.1
1212
github.com/gogo/protobuf v1.3.2
13+
github.com/golang-jwt/jwt/v5 v5.2.1
1314
github.com/google/uuid v1.6.0
1415
github.com/gorilla/mux v1.8.1
1516
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
@@ -30,6 +31,7 @@ require (
3031
go.opentelemetry.io/otel/sdk/metric v1.24.0
3132
go.opentelemetry.io/otel/trace v1.24.0
3233
go.uber.org/mock v0.4.0
34+
golang.org/x/oauth2 v0.17.0
3335
golang.org/x/term v0.18.0
3436
google.golang.org/api v0.166.0
3537
google.golang.org/grpc v1.61.1
@@ -70,9 +72,8 @@ require (
7072
go.einride.tech/aip v0.66.0 // indirect
7173
go.opencensus.io v0.24.0 // indirect
7274
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.48.0 // indirect
73-
golang.org/x/crypto v0.19.0 // indirect
74-
golang.org/x/net v0.21.0 // indirect
75-
golang.org/x/oauth2 v0.17.0 // indirect
75+
golang.org/x/crypto v0.21.0 // indirect
76+
golang.org/x/net v0.22.0 // indirect
7677
golang.org/x/sync v0.6.0 // indirect
7778
golang.org/x/sys v0.18.0 // indirect
7879
golang.org/x/text v0.14.0 // indirect

go.sum

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9
6969
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
7070
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
7171
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
72+
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
73+
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
7274
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
7375
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
7476
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
@@ -227,8 +229,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
227229
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
228230
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
229231
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
230-
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
231-
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
232+
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
233+
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
232234
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
233235
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
234236
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
@@ -252,8 +254,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
252254
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
253255
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
254256
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
255-
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
256-
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
257+
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
258+
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
257259
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
258260
golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ=
259261
golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA=

pkg/gofr/gofr.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"os"
88
"strconv"
99
"sync"
10+
"time"
1011

1112
"go.opentelemetry.io/otel"
1213
"go.opentelemetry.io/otel/exporters/zipkin"
@@ -18,6 +19,7 @@ import (
1819

1920
"gofr.dev/pkg/gofr/config"
2021
"gofr.dev/pkg/gofr/container"
22+
"gofr.dev/pkg/gofr/http/middleware"
2123
"gofr.dev/pkg/gofr/logging"
2224
"gofr.dev/pkg/gofr/metrics"
2325
"gofr.dev/pkg/gofr/migration"
@@ -227,6 +229,17 @@ func (a *App) Migrate(migrationsMap map[int64]migration.Migrate) {
227229
migration.Run(migrationsMap, a.container)
228230
}
229231

232+
func (a *App) EnableOAuth(jwksEndpoint string, refreshInterval int) {
233+
a.AddHTTPService("gofr_oauth", jwksEndpoint)
234+
235+
oauthOption := middleware.OauthConfigs{
236+
Provider: a.container.GetHTTPService("gofr_oauth"),
237+
RefreshInterval: time.Second * time.Duration(refreshInterval),
238+
}
239+
240+
a.httpServer.router.Use(middleware.OAuth(middleware.NewOAuth(oauthOption)))
241+
}
242+
230243
func (a *App) initTracer() {
231244
tracerHost := a.Config.Get("TRACER_HOST")
232245
tracerPort := a.Config.GetOrDefault("TRACER_PORT", "9411")

pkg/gofr/gofr_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313

1414
"github.com/stretchr/testify/assert"
1515

16+
"gofr.dev/pkg/gofr/container"
17+
gofrHTTP "gofr.dev/pkg/gofr/http"
1618
"gofr.dev/pkg/gofr/logging"
1719
"gofr.dev/pkg/gofr/migration"
1820
"gofr.dev/pkg/gofr/testutil"
@@ -215,3 +217,48 @@ func Test_addRoute(t *testing.T) {
215217

216218
assert.Contains(t, logs, "handler called")
217219
}
220+
221+
func TestEnableBasicAuthWithFunc(t *testing.T) {
222+
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
223+
w.WriteHeader(http.StatusOK)
224+
}))
225+
226+
c := container.NewContainer(testutil.NewMockConfig(nil))
227+
228+
// Initialize a new App instance
229+
a := &App{
230+
httpServer: &httpServer{
231+
router: gofrHTTP.NewRouter(c),
232+
},
233+
container: c,
234+
}
235+
236+
a.httpServer.router.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
237+
fmt.Println(w, "Hello, world!")
238+
}))
239+
240+
a.EnableOAuth(jwksServer.URL, 600)
241+
242+
server := httptest.NewServer(a.httpServer.router)
243+
defer server.Close()
244+
245+
client := server.Client()
246+
247+
// Create a mock HTTP request
248+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, http.NoBody)
249+
if err != nil {
250+
t.Fatal(err)
251+
}
252+
253+
// Add a basic authorization header
254+
req.Header.Add("Authorization", "dXNlcjpwYXNzd29yZA==")
255+
256+
// Send the HTTP request
257+
resp, err := client.Do(req)
258+
if err != nil {
259+
t.Fatal(err)
260+
}
261+
defer resp.Body.Close()
262+
263+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "TestEnableBasicAuthWithFunc Failed!")
264+
}

pkg/gofr/http/middleware/oauth.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"crypto/rsa"
6+
"encoding/base64"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"math/big"
11+
"net/http"
12+
"strings"
13+
"time"
14+
15+
"github.com/golang-jwt/jwt/v5"
16+
)
17+
18+
type JWTClaim string
19+
20+
type PublicKeys struct {
21+
keys map[string]*rsa.PublicKey
22+
}
23+
24+
type JWKNotFound struct {
25+
}
26+
27+
func (i JWKNotFound) Error() string {
28+
return "JWKS Not Found"
29+
}
30+
31+
func (p *PublicKeys) Get(kid string) *rsa.PublicKey {
32+
kid = strings.TrimSpace(kid)
33+
34+
return p.keys[kid]
35+
}
36+
37+
type JWKSProvider interface {
38+
GetWithHeaders(ctx context.Context, path string, queryParams map[string]interface{},
39+
headers map[string]string) (*http.Response, error)
40+
}
41+
42+
type OauthConfigs struct {
43+
Provider JWKSProvider
44+
RefreshInterval time.Duration
45+
}
46+
47+
func NewOAuth(config OauthConfigs) PublicKeyProvider {
48+
var publicKeys PublicKeys
49+
50+
publicKeys.keys = make(map[string]*rsa.PublicKey)
51+
52+
go func() {
53+
for {
54+
resp, err := config.Provider.GetWithHeaders(context.Background(), "", nil, nil)
55+
if err != nil || resp == nil {
56+
continue
57+
}
58+
59+
body, err := io.ReadAll(resp.Body)
60+
if err != nil {
61+
continue
62+
}
63+
64+
resp.Body.Close()
65+
66+
var jwks JWKS
67+
68+
err = json.Unmarshal(body, &jwks)
69+
if err != nil {
70+
continue
71+
}
72+
73+
publicKeys.keys = publicKeyFromJWKS(jwks)
74+
75+
time.Sleep(config.RefreshInterval)
76+
}
77+
}()
78+
79+
return &publicKeys
80+
}
81+
82+
type PublicKeyProvider interface {
83+
Get(kid string) *rsa.PublicKey
84+
}
85+
86+
func OAuth(key PublicKeyProvider) func(inner http.Handler) http.Handler {
87+
return func(inner http.Handler) http.Handler {
88+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
89+
authHeader := r.Header.Get("Authorization")
90+
if authHeader == "" {
91+
http.Error(w, "Authorization header is required", http.StatusUnauthorized)
92+
return
93+
}
94+
95+
headerParts := strings.Split(authHeader, " ")
96+
if len(headerParts) != 2 || headerParts[0] != "Bearer" {
97+
http.Error(w, "Authorization header format must be Bearer {token}", http.StatusUnauthorized)
98+
return
99+
}
100+
101+
tokenString := headerParts[1]
102+
103+
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
104+
kid := token.Header["kid"]
105+
106+
jwks := key.Get(fmt.Sprint(kid))
107+
if jwks == nil {
108+
return nil, JWKNotFound{}
109+
}
110+
111+
return key.Get(fmt.Sprint(kid)), nil
112+
})
113+
114+
if err != nil {
115+
w.WriteHeader(http.StatusUnauthorized)
116+
_, _ = w.Write([]byte(err.Error()))
117+
118+
return
119+
}
120+
121+
ctx := context.WithValue(r.Context(), JWTClaim("JWTClaims"), token.Claims)
122+
*r = *r.Clone(ctx)
123+
124+
inner.ServeHTTP(w, r)
125+
})
126+
}
127+
}
128+
129+
// JWKS represents a JSON Web Key Set.
130+
type JWKS struct {
131+
Keys []JSONWebKey `json:"keys"`
132+
}
133+
134+
type JSONWebKey struct {
135+
ID string `json:"kid"`
136+
Type string `json:"kty"`
137+
138+
Modulus string `json:"n"`
139+
PublicExponent string `json:"e"`
140+
PrivateExponent string `json:"d"`
141+
}
142+
143+
// PublicKeyFromJWKS creates a public key from a JWKS and returns it in string format.
144+
func publicKeyFromJWKS(jwks JWKS) map[string]*rsa.PublicKey {
145+
if len(jwks.Keys) == 0 {
146+
return nil
147+
}
148+
149+
keys := make(map[string]*rsa.PublicKey)
150+
151+
for _, jwk := range jwks.Keys {
152+
var val = jwk
153+
154+
keys[jwk.ID], _ = rsaPublicKeyStringFromJWK(&val)
155+
}
156+
157+
// Store the result of rsaPublicKeyStringFromJWK before the next iteration
158+
159+
return keys
160+
}
161+
162+
func rsaPublicKeyStringFromJWK(jwk *JSONWebKey) (*rsa.PublicKey, error) {
163+
n, err := base64.RawURLEncoding.DecodeString(jwk.Modulus)
164+
if err != nil {
165+
return nil, err
166+
}
167+
168+
e, err := base64.RawURLEncoding.DecodeString(jwk.PublicExponent)
169+
if err != nil {
170+
return nil, err
171+
}
172+
173+
nInt := new(big.Int).SetBytes(n)
174+
eInt := new(big.Int).SetBytes(e)
175+
176+
rsaPublicKey := &rsa.PublicKey{
177+
N: nInt,
178+
E: int(eInt.Int64()),
179+
}
180+
181+
return rsaPublicKey, nil
182+
}

0 commit comments

Comments
 (0)