@@ -384,15 +384,278 @@ async def test_get_access_token_refresh_expired(mocker):
384384 secret = "some-secret"
385385 )
386386
387- # Patch method that does the refresh call
388- mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
387+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
388+ "access_token" : "new_token" ,
389+ "expires_in" : 3600
390+ })
391+
392+ token = await client .get_access_token ()
393+ assert token == "new_token"
394+ mock_state_store .set .assert_awaited_once ()
395+ get_refresh_token_mock .assert_awaited_with ({
396+ "refresh_token" : "refresh_xyz"
397+ })
398+
399+ @pytest .mark .asyncio
400+ async def test_get_access_token_refresh_merging_default_scope (mocker ):
401+ mock_state_store = AsyncMock ()
402+ # expired token
403+ mock_state_store .get .return_value = {
404+ "refresh_token" : "refresh_xyz" ,
405+ "token_sets" : [
406+ {
407+ "audience" : "default" ,
408+ "access_token" : "expired_token" ,
409+ "expires_at" : int (time .time ()) - 500
410+ }
411+ ]
412+ }
413+
414+ client = ServerClient (
415+ domain = "auth0.local" ,
416+ client_id = "client_id" ,
417+ client_secret = "client_secret" ,
418+ transaction_store = AsyncMock (),
419+ state_store = mock_state_store ,
420+ secret = "some-secret" ,
421+ authorization_params = {
422+ "audience" : "default" ,
423+ "scope" : "openid profile email"
424+ }
425+ )
426+
427+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
428+ "access_token" : "new_token" ,
429+ "expires_in" : 3600
430+ })
431+
432+ token = await client .get_access_token (scope = "foo:bar" )
433+ assert token == "new_token"
434+ mock_state_store .set .assert_awaited_once ()
435+ get_refresh_token_mock .assert_awaited_with ({
436+ "refresh_token" : "refresh_xyz" ,
437+ "audience" : "default" ,
438+ "scope" : "openid profile email foo:bar"
439+ })
440+
441+ @pytest .mark .asyncio
442+ async def test_get_access_token_refresh_with_auth_params_scope (mocker ):
443+ mock_state_store = AsyncMock ()
444+ # expired token
445+ mock_state_store .get .return_value = {
446+ "refresh_token" : "refresh_xyz" ,
447+ "token_sets" : [
448+ {
449+ "audience" : "default" ,
450+ "access_token" : "expired_token" ,
451+ "expires_at" : int (time .time ()) - 500
452+ }
453+ ]
454+ }
455+
456+ client = ServerClient (
457+ domain = "auth0.local" ,
458+ client_id = "client_id" ,
459+ client_secret = "client_secret" ,
460+ transaction_store = AsyncMock (),
461+ state_store = mock_state_store ,
462+ secret = "some-secret" ,
463+ authorization_params = {
464+ "scope" : "openid profile email"
465+ }
466+ )
467+
468+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
469+ "access_token" : "new_token" ,
470+ "expires_in" : 3600
471+ })
472+
473+ token = await client .get_access_token ()
474+ assert token == "new_token"
475+ mock_state_store .set .assert_awaited_once ()
476+ get_refresh_token_mock .assert_awaited_with ({
477+ "refresh_token" : "refresh_xyz" ,
478+ "scope" : "openid profile email"
479+ })
480+
481+ @pytest .mark .asyncio
482+ async def test_get_access_token_refresh_with_auth_params_audience (mocker ):
483+ mock_state_store = AsyncMock ()
484+ # expired token
485+ mock_state_store .get .return_value = {
486+ "refresh_token" : "refresh_xyz" ,
487+ "token_sets" : [
488+ {
489+ "audience" : "my_audience" ,
490+ "access_token" : "expired_token" ,
491+ "expires_at" : int (time .time ()) - 500
492+ }
493+ ]
494+ }
495+
496+ client = ServerClient (
497+ domain = "auth0.local" ,
498+ client_id = "client_id" ,
499+ client_secret = "client_secret" ,
500+ transaction_store = AsyncMock (),
501+ state_store = mock_state_store ,
502+ secret = "some-secret" ,
503+ authorization_params = {
504+ "audience" : "my_audience"
505+ }
506+ )
507+
508+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
389509 "access_token" : "new_token" ,
390510 "expires_in" : 3600
391511 })
392512
393513 token = await client .get_access_token ()
394514 assert token == "new_token"
395515 mock_state_store .set .assert_awaited_once ()
516+ get_refresh_token_mock .assert_awaited_with ({
517+ "refresh_token" : "refresh_xyz" ,
518+ "audience" : "my_audience"
519+ })
520+
521+ @pytest .mark .asyncio
522+ async def test_get_access_token_mrrt (mocker ):
523+ mock_state_store = AsyncMock ()
524+ # expired token
525+ mock_state_store .get .return_value = {
526+ "refresh_token" : "refresh_xyz" ,
527+ "token_sets" : [
528+ {
529+ "audience" : "default" ,
530+ "access_token" : "valid_token_for_other_audience" ,
531+ "expires_at" : int (time .time ()) + 500
532+ }
533+ ]
534+ }
535+
536+ client = ServerClient (
537+ domain = "auth0.local" ,
538+ client_id = "client_id" ,
539+ client_secret = "client_secret" ,
540+ transaction_store = AsyncMock (),
541+ state_store = mock_state_store ,
542+ secret = "some-secret"
543+ )
544+
545+ # Patch method that does the refresh call
546+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
547+ "access_token" : "new_token" ,
548+ "expires_in" : 3600
549+ })
550+
551+ token = await client .get_access_token (
552+ audience = "some_audience" ,
553+ scope = "foo:bar"
554+ )
555+
556+ assert token == "new_token"
557+ mock_state_store .set .assert_awaited_once ()
558+ args , kwargs = mock_state_store .set .call_args
559+ stored_state = args [1 ]
560+ assert "token_sets" in stored_state
561+ assert len (stored_state ["token_sets" ]) == 2
562+ get_refresh_token_mock .assert_awaited_with ({
563+ "refresh_token" : "refresh_xyz" ,
564+ "audience" : "some_audience" ,
565+ "scope" : "foo:bar" ,
566+ })
567+
568+ @pytest .mark .asyncio
569+ async def test_get_access_token_mrrt_with_auth_params_scope (mocker ):
570+ mock_state_store = AsyncMock ()
571+ # expired token
572+ mock_state_store .get .return_value = {
573+ "refresh_token" : "refresh_xyz" ,
574+ "token_sets" : [
575+ {
576+ "audience" : "default" ,
577+ "access_token" : "valid_token_for_other_audience" ,
578+ "expires_at" : int (time .time ()) + 500
579+ }
580+ ]
581+ }
582+
583+ client = ServerClient (
584+ domain = "auth0.local" ,
585+ client_id = "client_id" ,
586+ client_secret = "client_secret" ,
587+ transaction_store = AsyncMock (),
588+ state_store = mock_state_store ,
589+ secret = "some-secret" ,
590+ authorization_params = {
591+ "audience" : "default" ,
592+ "scope" : {
593+ "default" : "openid profile email foo:bar" ,
594+ "some_audience" : "foo:bar"
595+ }
596+ }
597+ )
598+
599+ # Patch method that does the refresh call
600+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" , return_value = {
601+ "access_token" : "new_token" ,
602+ "expires_in" : 3600
603+ })
604+
605+ token = await client .get_access_token (
606+ audience = "some_audience"
607+ )
608+
609+ assert token == "new_token"
610+ mock_state_store .set .assert_awaited_once ()
611+ args , kwargs = mock_state_store .set .call_args
612+ stored_state = args [1 ]
613+ assert "token_sets" in stored_state
614+ assert len (stored_state ["token_sets" ]) == 2
615+ get_refresh_token_mock .assert_awaited_with ({
616+ "refresh_token" : "refresh_xyz" ,
617+ "audience" : "some_audience" ,
618+ "scope" : "foo:bar" ,
619+ })
620+
621+ @pytest .mark .asyncio
622+ async def test_get_access_token_from_store_with_multilpe_audiences (mocker ):
623+ mock_state_store = AsyncMock ()
624+ mock_state_store .get .return_value = {
625+ "refresh_token" : None ,
626+ "token_sets" : [
627+ {
628+ "audience" : "default" ,
629+ "access_token" : "token_from_store" ,
630+ "expires_at" : int (time .time ()) + 500
631+ },
632+ {
633+ "audience" : "some_audience" ,
634+ "access_token" : "other_token_from_store" ,
635+ "scope" : "foo:bar" ,
636+ "expires_at" : int (time .time ()) + 500
637+ }
638+ ]
639+ }
640+
641+ client = ServerClient (
642+ domain = "auth0.local" ,
643+ client_id = "client_id" ,
644+ client_secret = "client_secret" ,
645+ transaction_store = AsyncMock (),
646+ state_store = mock_state_store ,
647+ secret = "some-secret"
648+ )
649+
650+ get_refresh_token_mock = mocker .patch .object (client , "get_token_by_refresh_token" )
651+
652+ token = await client .get_access_token (
653+ audience = "some_audience" ,
654+ scope = "foo:bar"
655+ )
656+
657+ assert token == "other_token_from_store"
658+ get_refresh_token_mock .assert_not_awaited ()
396659
397660@pytest .mark .asyncio
398661async def test_get_access_token_for_connection_cached ():
0 commit comments