@@ -37,6 +37,7 @@ import (
3737 "github.com/cisco-open/go-lanai/pkg/security/idp/passwdidp"
3838 "github.com/cisco-open/go-lanai/pkg/security/logout"
3939 "github.com/cisco-open/go-lanai/pkg/security/oauth2"
40+ "github.com/cisco-open/go-lanai/pkg/security/oauth2/auth"
4041 "github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/authorize"
4142 "github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/clientauth"
4243 "github.com/cisco-open/go-lanai/pkg/security/oauth2/auth/token"
@@ -59,6 +60,7 @@ import (
5960 . "github.com/cisco-open/go-lanai/test/utils/gomega"
6061 "github.com/cisco-open/go-lanai/test/webtest"
6162 "github.com/crewjam/saml"
63+ "github.com/golang-jwt/jwt/v4"
6264 "github.com/google/uuid"
6365 "github.com/onsi/gomega"
6466 . "github.com/onsi/gomega"
@@ -157,6 +159,7 @@ type intDI struct {
157159 Mocking sectest.MockingProperties
158160 TokenReader oauth2.TokenStoreReader
159161 SessionStore session.Store
162+ AuthReg auth.AuthorizationRegistry `optional:"true"`
160163}
161164
162165func TestWithMockedServer (t * testing.T ) {
@@ -189,6 +192,7 @@ func TestWithMockedServer(t *testing.T) {
189192 testdata .NewAuthServerConfigurer , //This configurer will set up mocked client store, mocked tenant store etc.
190193 testdata .NewResServerConfigurer ,
191194 testdata .NewMockedApprovalStore ,
195+ auth .NewLegacyTokenEnhancer ,
192196 ),
193197 ),
194198 test .GomegaSubTest (SubTestOAuth2AuthorizeWithPasswdIDP (di ), "TestOAuth2AuthorizeWithPasswdIDP" ),
@@ -234,6 +238,7 @@ func TestWithMockedServerWithoutFinalizer(t *testing.T) {
234238 sectest .BindMockingProperties ,
235239 testdata .NewAuthServerConfigurer ,
236240 testdata .NewResServerConfigurer ,
241+ auth .NewLegacyTokenEnhancer ,
237242 ),
238243 ),
239244 // a user has access to two tenants, switch from one to the other
@@ -242,6 +247,39 @@ func TestWithMockedServerWithoutFinalizer(t *testing.T) {
242247 )
243248}
244249
250+ func TestWithMockedServerWithCustomTokenGranter (t * testing.T ) {
251+ di := & intDI {}
252+ test .RunTest (context .Background (), t ,
253+ apptest .Bootstrap (),
254+ apptest .WithTimeout (2 * time .Minute ),
255+ webtest .WithMockedServer (),
256+ sectest .WithMockedMiddleware (sectest .MWEnableSession ()),
257+ apptest .WithModules (
258+ authserver .Module , resserver .Module ,
259+ passwdidp .Module , extsamlidp .Module , authorize .Module , samlidp .Module ,
260+ passwd .Module , formlogin .Module , logout .Module ,
261+ samlctx .Module , samlsp .Module ,
262+ basicauth .Module , clientauth .Module ,
263+ token .Module , access .Module , errorhandling .Module ,
264+ request_cache .Module , csrf .Module , session .Module ,
265+ redis .Module ,
266+ ),
267+ apptest .WithDI (di ),
268+ apptest .WithFxOptions (
269+ fx .Provide (
270+ IntegrationTestMocksProvider (),
271+ sectest .BindMockingProperties ,
272+ testdata .NewAuthServerConfigurer ,
273+ testdata .NewResServerConfigurer ,
274+ testdata .NewCustomTokenEnhancer ,
275+ testdata .NewCustomAuthRegistry ,
276+ testdata .NewCustomTokenGranter ,
277+ ),
278+ ),
279+ test .GomegaSubTest (SubTestCustomTokenGranter (di ), "TestCustomTokenGranter" ),
280+ )
281+ }
282+
245283/*************************
246284 Sub Tests
247285 *************************/
@@ -914,6 +952,31 @@ func SubTestOauth2SwitchTenant(
914952 }
915953}
916954
955+ func SubTestCustomTokenGranter (
956+ di * intDI ,
957+ ) test.GomegaSubTestFunc {
958+ return func (ctx context.Context , t * testing.T , g * gomega.WithT ) {
959+ req := webtest .NewRequest (ctx , http .MethodPost , "/v2/token" , customGrantReqBody (), withClientAuth ("custom-grant-client" , TestClientSecret ), tokenReqOptions ())
960+ resp := webtest .MustExec (ctx , req )
961+ g .Expect (resp ).ToNot (BeNil (), "response should not be nil" )
962+ g .Expect (resp .Response .StatusCode ).To (Equal (http .StatusOK ), "response should have correct status code" )
963+
964+ body , e := io .ReadAll (resp .Response .Body )
965+ g .Expect (e ).To (Succeed (), `token response body should be readable` )
966+ g .Expect (body ).To (HaveJsonPath ("$.access_token" ), "token response should have access_token" )
967+
968+ accessToken := oauth2 .NewDefaultAccessToken ("" )
969+ e = json .Unmarshal (body , accessToken )
970+ g .Expect (e ).ToNot (HaveOccurred ())
971+
972+ tk , _ , e := jwt .NewParser ().ParseUnverified (accessToken .Value (), jwt.MapClaims {})
973+ g .Expect (e ).ToNot (HaveOccurred ())
974+ g .Expect (tk .Claims .(jwt.MapClaims )["MyClaim" ]).To (Equal ("my_claim_value" ))
975+
976+ g .Expect (di .AuthReg .(* testdata.CustomAuthRegistry ).RegistrationCount ).To (Equal (1 ))
977+ }
978+ }
979+
917980/*************************
918981 Helpers
919982 *************************/
@@ -1103,6 +1166,12 @@ func passwordGrantReqBody(tenantId string, username string, password string) io.
11031166 return strings .NewReader (values .Encode ())
11041167}
11051168
1169+ func customGrantReqBody () io.Reader {
1170+ values := url.Values {}
1171+ values .Set (oauth2 .ParameterGrantType , "custom_grant" )
1172+ return strings .NewReader (values .Encode ())
1173+ }
1174+
11061175func tokenReqOptions () webtest.RequestOptions {
11071176 return func (req * http.Request ) {
11081177 req .Header .Set ("Content-Type" , "application/x-www-form-urlencoded" )
0 commit comments