Skip to content

Commit 95abdbe

Browse files
authored
Merge branch 'main' into 3318-RavenDB-state-store-new
2 parents 9b591b6 + eae3312 commit 95abdbe

File tree

5 files changed

+327
-17
lines changed

5 files changed

+327
-17
lines changed

middleware/http/oauth2/oauth2_middleware.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"net/http"
1919
"net/url"
2020
"reflect"
21+
"regexp"
2122
"strings"
2223

2324
"github.com/fasthttp-contrib/sessions"
@@ -42,6 +43,9 @@ type oAuth2MiddlewareMetadata struct {
4243
AuthHeaderName string `json:"authHeaderName" mapstructure:"authHeaderName"`
4344
RedirectURL string `json:"redirectURL" mapstructure:"redirectURL"`
4445
ForceHTTPS string `json:"forceHTTPS" mapstructure:"forceHTTPS"`
46+
PathFilter string `json:"pathFilter" mapstructure:"pathFilter"`
47+
48+
pathFilterRegex *regexp.Regexp
4549
}
4650

4751
// NewOAuth2Middleware returns a new oAuth2 middleware.
@@ -84,6 +88,15 @@ func (m *Middleware) GetHandler(ctx context.Context, metadata middleware.Metadat
8488

8589
return func(next http.Handler) http.Handler {
8690
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91+
if meta.pathFilterRegex != nil {
92+
matched := meta.pathFilterRegex.MatchString(r.URL.Path)
93+
if !matched {
94+
m.logger.Debugf("PathFilter %s didn't match %s! Skipping!", meta.PathFilter, r.URL.Path)
95+
next.ServeHTTP(w, r)
96+
return
97+
}
98+
}
99+
87100
session := sessions.Start(w, r)
88101

89102
if session.GetString(meta.AuthHeaderName) != "" {
@@ -153,6 +166,15 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2Mid
153166
if err != nil {
154167
return nil, err
155168
}
169+
170+
if middlewareMetadata.PathFilter != "" {
171+
rx, err := regexp.Compile(middlewareMetadata.PathFilter)
172+
if err != nil {
173+
return nil, err
174+
}
175+
middlewareMetadata.pathFilterRegex = rx
176+
}
177+
156178
return &middlewareMetadata, nil
157179
}
158180

middleware/http/oauth2/oauth2_middleware_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,52 @@ func TestOAuth2CreatesAuthorizationHeaderWhenInSessionState(t *testing.T) {
6161

6262
assert.Equal(t, "Bearer abcd", r.Header.Get("someHeader"))
6363
}
64+
65+
func TestOAuth2CreatesAuthorizationHeaderGetNativeMetadata(t *testing.T) {
66+
var metadata middleware.Metadata
67+
metadata.Properties = map[string]string{
68+
"clientID": "testId",
69+
"clientSecret": "testSecret",
70+
"scopes": "ascope",
71+
"authURL": "https://idp:9999",
72+
"tokenURL": "https://idp:9999",
73+
"redirectUrl": "https://localhost:9999",
74+
"authHeaderName": "someHeader",
75+
}
76+
77+
log := logger.NewLogger("oauth2.test")
78+
oauth2Middleware, ok := NewOAuth2Middleware(log).(*Middleware)
79+
require.True(t, ok)
80+
81+
tc := []struct {
82+
name string
83+
pathFilter string
84+
wantErr bool
85+
}{
86+
{name: "empty pathFilter", pathFilter: "", wantErr: false},
87+
{name: "wildcard pathFilter", pathFilter: ".*", wantErr: false},
88+
{name: "api path pathFilter", pathFilter: "/api/v1/users", wantErr: false},
89+
{name: "debug endpoint pathFilter", pathFilter: "^/debug/?$", wantErr: false},
90+
{name: "user id pathFilter", pathFilter: "^/user/[0-9]+$", wantErr: false},
91+
{name: "invalid wildcard pathFilter", pathFilter: "*invalid", wantErr: true},
92+
{name: "unclosed parenthesis pathFilter", pathFilter: "invalid(", wantErr: true},
93+
{name: "unopened parenthesis pathFilter", pathFilter: "invalid)", wantErr: true},
94+
}
95+
96+
for _, tt := range tc {
97+
t.Run(tt.name, func(t *testing.T) {
98+
metadata.Properties["pathFilter"] = tt.pathFilter
99+
nativeMetadata, err := oauth2Middleware.getNativeMetadata(metadata)
100+
if tt.wantErr {
101+
require.Error(t, err)
102+
} else {
103+
require.NoError(t, err)
104+
if tt.pathFilter != "" {
105+
require.NotNil(t, nativeMetadata.pathFilterRegex)
106+
} else {
107+
require.Nil(t, nativeMetadata.pathFilterRegex)
108+
}
109+
}
110+
})
111+
}
112+
}

middleware/http/oauth2clientcredentials/oauth2clientcredentials_middleware.go

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"net/http"
2222
"net/url"
2323
"reflect"
24+
"regexp"
2425
"strings"
2526
"time"
2627

@@ -43,6 +44,9 @@ type oAuth2ClientCredentialsMiddlewareMetadata struct {
4344
HeaderName string `json:"headerName" mapstructure:"headerName"`
4445
EndpointParamsQuery string `json:"endpointParamsQuery,omitempty" mapstructure:"endpointParamsQuery"`
4546
AuthStyle int `json:"authStyle" mapstructure:"authStyle"`
47+
PathFilter string `json:"pathFilter" mapstructure:"pathFilter"`
48+
49+
pathFilterRegex *regexp.Regexp
4650
}
4751

4852
// TokenProviderInterface provides a common interface to Mock the Token retrieval in unit tests.
@@ -69,7 +73,7 @@ type Middleware struct {
6973
tokenProvider TokenProviderInterface
7074
}
7175

72-
// GetHandler retruns the HTTP handler provided by the middleware.
76+
// GetHandler returns the HTTP handler provided by the middleware.
7377
func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata) (func(next http.Handler) http.Handler, error) {
7478
meta, err := m.getNativeMetadata(metadata)
7579
if err != nil {
@@ -98,27 +102,38 @@ func (m *Middleware) GetHandler(_ context.Context, metadata middleware.Metadata)
98102
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
99103
var headerValue string
100104

101-
// Check if valid token is in the cache
102-
cachedToken, found := m.tokenCache.Get(cacheKey)
103-
if !found {
104-
m.log.Debugf("Cached token not found, try get one")
105-
106-
token, err := m.tokenProvider.GetToken(r.Context(), conf)
107-
if err != nil {
108-
m.log.Errorf("Error acquiring token: %s", err)
105+
if meta.pathFilterRegex != nil {
106+
matched := meta.pathFilterRegex.MatchString(r.URL.Path)
107+
if !matched {
108+
m.log.Debugf("PathFilter %s didn't match %s! Skipping!", meta.PathFilter, r.URL.Path)
109+
next.ServeHTTP(w, r)
109110
return
110111
}
112+
}
111113

112-
tokenExpirationDuration := time.Until(token.Expiry)
113-
m.log.Debugf("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
114-
115-
headerValue = token.Type() + " " + token.AccessToken
116-
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
117-
} else {
114+
// Check if valid token is in the cache
115+
cachedToken, found := m.tokenCache.Get(cacheKey)
116+
if found {
118117
m.log.Debugf("Cached token found for key %s", cacheKey)
119118
headerValue = cachedToken.(string)
119+
r.Header.Add(meta.HeaderName, headerValue)
120+
next.ServeHTTP(w, r)
121+
return
122+
}
123+
124+
m.log.Infof("Cached token not found, attempting to retrieve a new one")
125+
token, err := m.tokenProvider.GetToken(r.Context(), conf)
126+
if err != nil {
127+
m.log.Errorf("Error acquiring token: %s", err)
128+
return
120129
}
121130

131+
tokenExpirationDuration := time.Until(token.Expiry)
132+
m.log.Infof("Token expires at %s (%s from now)", token.Expiry, tokenExpirationDuration)
133+
134+
headerValue = token.Type() + " " + token.AccessToken
135+
m.tokenCache.Set(cacheKey, headerValue, tokenExpirationDuration)
136+
122137
r.Header.Add(meta.HeaderName, headerValue)
123138
next.ServeHTTP(w, r)
124139
})
@@ -142,6 +157,14 @@ func (m *Middleware) getNativeMetadata(metadata middleware.Metadata) (*oAuth2Cli
142157
m.checkMetadataValueExists(&errorString, &middlewareMetadata.Scopes, "scopes")
143158
m.checkMetadataValueExists(&errorString, &middlewareMetadata.TokenURL, "tokenURL")
144159

160+
if middlewareMetadata.PathFilter != "" {
161+
rx, err := regexp.Compile(middlewareMetadata.PathFilter)
162+
if err != nil {
163+
errorString += "Parameter 'pathFilter' is not a valid regex: " + err.Error() + ". "
164+
}
165+
middlewareMetadata.pathFilterRegex = rx
166+
}
167+
145168
// Value-check AuthStyle
146169
if middlewareMetadata.AuthStyle < 0 || middlewareMetadata.AuthStyle > 2 {
147170
errorString += fmt.Sprintf("Parameter 'authStyle' can only have the values 0,1,2. Received: '%d'. ", middlewareMetadata.AuthStyle)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
Copyright 2025 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
*/
13+
14+
package oauth2clientcredentials
15+
16+
import (
17+
"fmt"
18+
"net/http"
19+
"net/http/httptest"
20+
"testing"
21+
"time"
22+
23+
"github.com/golang/mock/gomock"
24+
"github.com/stretchr/testify/require"
25+
"golang.org/x/oauth2"
26+
27+
"github.com/dapr/components-contrib/middleware"
28+
mock "github.com/dapr/components-contrib/middleware/http/oauth2clientcredentials/mocks"
29+
"github.com/dapr/kit/logger"
30+
)
31+
32+
func BenchmarkTestOAuth2ClientCredentialsGetHandler(b *testing.B) {
33+
mockCtrl := gomock.NewController(b)
34+
defer mockCtrl.Finish()
35+
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
36+
gomock.InOrder(
37+
mockTokenProvider.
38+
EXPECT().
39+
GetToken(gomock.Any()).
40+
Return(&oauth2.Token{
41+
AccessToken: "abcd",
42+
TokenType: "Bearer",
43+
Expiry: time.Now().Add(1 * time.Minute),
44+
}, nil).
45+
Times(1),
46+
)
47+
48+
var metadata middleware.Metadata
49+
metadata.Properties = map[string]string{
50+
"clientID": "testId",
51+
"clientSecret": "testSecret",
52+
"scopes": "ascope",
53+
"tokenURL": "https://localhost:9999",
54+
"headerName": "authorization",
55+
"authStyle": "1",
56+
}
57+
58+
log := logger.NewLogger("oauth2clientcredentials.test")
59+
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
60+
require.True(b, ok)
61+
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
62+
handler, err := oauth2clientcredentialsMiddleware.GetHandler(b.Context(), metadata)
63+
require.NoError(b, err)
64+
65+
for i := range b.N {
66+
url := fmt.Sprintf("http://dapr.io/api/v1/users/%d", i)
67+
r := httptest.NewRequest(http.MethodGet, url, nil)
68+
w := httptest.NewRecorder()
69+
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
70+
}
71+
}
72+
73+
func BenchmarkTestOAuth2ClientCredentialsGetHandlerWithPathFilter(b *testing.B) {
74+
mockCtrl := gomock.NewController(b)
75+
defer mockCtrl.Finish()
76+
mockTokenProvider := mock.NewMockTokenProviderInterface(mockCtrl)
77+
gomock.InOrder(
78+
mockTokenProvider.
79+
EXPECT().
80+
GetToken(gomock.Any()).
81+
Return(&oauth2.Token{
82+
AccessToken: "abcd",
83+
TokenType: "Bearer",
84+
Expiry: time.Now().Add(1 * time.Minute),
85+
}, nil).
86+
Times(1),
87+
)
88+
89+
var metadata middleware.Metadata
90+
metadata.Properties = map[string]string{
91+
"clientID": "testId",
92+
"clientSecret": "testSecret",
93+
"scopes": "ascope",
94+
"tokenURL": "https://localhost:9999",
95+
"headerName": "authorization",
96+
"authStyle": "1",
97+
"pathFilter": "/api/v1/users/.*",
98+
}
99+
100+
log := logger.NewLogger("oauth2clientcredentials.test")
101+
oauth2clientcredentialsMiddleware, ok := NewOAuth2ClientCredentialsMiddleware(log).(*Middleware)
102+
require.True(b, ok)
103+
oauth2clientcredentialsMiddleware.SetTokenProvider(mockTokenProvider)
104+
handler, err := oauth2clientcredentialsMiddleware.GetHandler(b.Context(), metadata)
105+
require.NoError(b, err)
106+
107+
for i := range b.N {
108+
url := fmt.Sprintf("http://dapr.io/api/v1/users/%d", i)
109+
r := httptest.NewRequest(http.MethodGet, url, nil)
110+
w := httptest.NewRecorder()
111+
handler(http.HandlerFunc(mockedRequestHandler)).ServeHTTP(w, r)
112+
}
113+
}

0 commit comments

Comments
 (0)