Skip to content

Commit e295939

Browse files
committed
Add scopes per audience and merging of default/request scopes
1 parent 14b8b07 commit e295939

File tree

2 files changed

+318
-17
lines changed

2 files changed

+318
-17
lines changed

src/auth0_server_python/auth_server/server_client.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class ServerClient(Generic[TStoreOptions]):
4848
Main client for Auth0 server SDK. Handles authentication flows, session management,
4949
and token operations using Authlib for OIDC functionality.
5050
"""
51+
DEFAULT_AUDIENCE_STATE_KEY = "default"
5152

5253
def __init__(
5354
self,
@@ -292,7 +293,7 @@ async def complete_interactive_login(
292293

293294
# Build a token set using the token response data
294295
token_set = TokenSet(
295-
audience=transaction_data.audience or "default",
296+
audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY,
296297
access_token=token_response.get("access_token", ""),
297298
scope=token_response.get("scope", ""),
298299
expires_at=int(time.time()) +
@@ -511,7 +512,7 @@ async def login_backchannel(
511512
existing_state_data = await self._state_store.get(self._state_identifier, store_options)
512513

513514
audience = self._default_authorization_params.get(
514-
"audience", "default")
515+
"audience", self.DEFAULT_AUDIENCE_STATE_KEY)
515516

516517
state_data = State.update_state_data(
517518
audience,
@@ -586,12 +587,17 @@ async def get_access_token(
586587
"""
587588
state_data = await self._state_store.get(self._state_identifier, store_options)
588589

589-
# Get audience and scope from options or use defaults
590590
auth_params = self._default_authorization_params or {}
591+
592+
# Get audience options or use defaults
591593
if not audience:
592-
audience = auth_params.get("audience", "default")
593-
if not scope:
594-
scope = auth_params.get("scope")
594+
audience = auth_params.get("audience", None)
595+
596+
merged_scope = self._get_scope_to_request(
597+
scope,
598+
auth_params.get("scope", None),
599+
audience or self.DEFAULT_AUDIENCE_STATE_KEY
600+
)
595601

596602
if state_data and hasattr(state_data, "dict") and callable(state_data.dict):
597603
state_data_dict = state_data.dict()
@@ -601,10 +607,8 @@ async def get_access_token(
601607
# Find matching token set
602608
token_set = None
603609
if state_data_dict and "token_sets" in state_data_dict:
604-
for ts in state_data_dict["token_sets"]:
605-
if ts.get("audience") == audience and (not scope or ts.get("scope") == scope):
606-
token_set = ts
607-
break
610+
token_set = self._find_matching_token_set(
611+
state_data_dict["token_sets"], audience or self.DEFAULT_AUDIENCE_STATE_KEY, merged_scope)
608612

609613
# If token is valid, return it
610614
if token_set and token_set.get("expires_at", 0) > time.time():
@@ -619,11 +623,14 @@ async def get_access_token(
619623

620624
# Get new token with refresh token
621625
try:
622-
token_endpoint_response = await self.get_token_by_refresh_token({
623-
"refresh_token": state_data_dict["refresh_token"],
624-
"audience": audience,
625-
"scope": scope
626-
})
626+
request_body = {"refresh_token": state_data_dict["refresh_token"]}
627+
if audience:
628+
request_body["audience"] = audience
629+
630+
if merged_scope:
631+
request_body["scope"] = merged_scope
632+
633+
token_endpoint_response = await self.get_token_by_refresh_token(request_body)
627634

628635
# Update state data with new token
629636
existing_state_data = await self._state_store.get(self._state_identifier, store_options)
@@ -642,6 +649,37 @@ async def get_access_token(
642649
f"Failed to get token with refresh token: {str(e)}"
643650
)
644651

652+
def _get_scope_to_request(
653+
self,
654+
request_scopes: Optional[str],
655+
default_scopes: Optional[str] | Optional[dict[str, str]],
656+
audience: Optional[str]
657+
) -> Optional[str]:
658+
# For backwards compatibility, allow scope to be a single string
659+
# or dictionary by audience for MRRT
660+
if isinstance(default_scopes, dict) and audience in default_scopes:
661+
default_scopes = default_scopes[audience]
662+
663+
default_scopes_list = (default_scopes or "").split()
664+
request_scopes_list = (request_scopes or "").split()
665+
666+
merged_scopes = default_scopes_list + [x for x in request_scopes_list if x not in default_scopes_list]
667+
return " ".join(merged_scopes) if merged_scopes else None
668+
669+
670+
def _find_matching_token_set(
671+
self,
672+
token_sets: list[dict[str, Any]],
673+
audience: Optional[str],
674+
scope: Optional[str]
675+
) -> Optional[dict[str, Any]]:
676+
for token_set in token_sets:
677+
token_set_audience = token_set.get("audience")
678+
matches_audience = token_set_audience == audience
679+
matches_scope = not scope or token_set.get("scope", None) == scope
680+
if matches_audience and matches_scope:
681+
return token_set
682+
645683
async def get_access_token_for_connection(
646684
self,
647685
options: dict[str, Any],

src/auth0_server_python/tests/test_server_client.py

Lines changed: 265 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,15 +384,278 @@ async def test_get_access_token_refresh_expired(mocker):
384384
secret="some-secret"
385385
)
386386

387-
# Patch method that does the refresh call
388-
mocker.patch.object(client, "get_token_by_refresh_token", return_value={
387+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={
388+
"access_token": "new_token",
389+
"expires_in": 3600
390+
})
391+
392+
token = await client.get_access_token()
393+
assert token == "new_token"
394+
mock_state_store.set.assert_awaited_once()
395+
get_refresh_token_mock.assert_awaited_with({
396+
"refresh_token": "refresh_xyz"
397+
})
398+
399+
@pytest.mark.asyncio
400+
async def test_get_access_token_refresh_merging_default_scope(mocker):
401+
mock_state_store = AsyncMock()
402+
# expired token
403+
mock_state_store.get.return_value = {
404+
"refresh_token": "refresh_xyz",
405+
"token_sets": [
406+
{
407+
"audience": "default",
408+
"access_token": "expired_token",
409+
"expires_at": int(time.time()) - 500
410+
}
411+
]
412+
}
413+
414+
client = ServerClient(
415+
domain="auth0.local",
416+
client_id="client_id",
417+
client_secret="client_secret",
418+
transaction_store=AsyncMock(),
419+
state_store=mock_state_store,
420+
secret="some-secret",
421+
authorization_params= {
422+
"audience": "default",
423+
"scope": "openid profile email"
424+
}
425+
)
426+
427+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={
428+
"access_token": "new_token",
429+
"expires_in": 3600
430+
})
431+
432+
token = await client.get_access_token(scope="foo:bar")
433+
assert token == "new_token"
434+
mock_state_store.set.assert_awaited_once()
435+
get_refresh_token_mock.assert_awaited_with({
436+
"refresh_token": "refresh_xyz",
437+
"audience": "default",
438+
"scope": "openid profile email foo:bar"
439+
})
440+
441+
@pytest.mark.asyncio
442+
async def test_get_access_token_refresh_with_auth_params_scope(mocker):
443+
mock_state_store = AsyncMock()
444+
# expired token
445+
mock_state_store.get.return_value = {
446+
"refresh_token": "refresh_xyz",
447+
"token_sets": [
448+
{
449+
"audience": "default",
450+
"access_token": "expired_token",
451+
"expires_at": int(time.time()) - 500
452+
}
453+
]
454+
}
455+
456+
client = ServerClient(
457+
domain="auth0.local",
458+
client_id="client_id",
459+
client_secret="client_secret",
460+
transaction_store=AsyncMock(),
461+
state_store=mock_state_store,
462+
secret="some-secret",
463+
authorization_params= {
464+
"scope": "openid profile email"
465+
}
466+
)
467+
468+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={
469+
"access_token": "new_token",
470+
"expires_in": 3600
471+
})
472+
473+
token = await client.get_access_token()
474+
assert token == "new_token"
475+
mock_state_store.set.assert_awaited_once()
476+
get_refresh_token_mock.assert_awaited_with({
477+
"refresh_token": "refresh_xyz",
478+
"scope": "openid profile email"
479+
})
480+
481+
@pytest.mark.asyncio
482+
async def test_get_access_token_refresh_with_auth_params_audience(mocker):
483+
mock_state_store = AsyncMock()
484+
# expired token
485+
mock_state_store.get.return_value = {
486+
"refresh_token": "refresh_xyz",
487+
"token_sets": [
488+
{
489+
"audience": "my_audience",
490+
"access_token": "expired_token",
491+
"expires_at": int(time.time()) - 500
492+
}
493+
]
494+
}
495+
496+
client = ServerClient(
497+
domain="auth0.local",
498+
client_id="client_id",
499+
client_secret="client_secret",
500+
transaction_store=AsyncMock(),
501+
state_store=mock_state_store,
502+
secret="some-secret",
503+
authorization_params= {
504+
"audience": "my_audience"
505+
}
506+
)
507+
508+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={
389509
"access_token": "new_token",
390510
"expires_in": 3600
391511
})
392512

393513
token = await client.get_access_token()
394514
assert token == "new_token"
395515
mock_state_store.set.assert_awaited_once()
516+
get_refresh_token_mock.assert_awaited_with({
517+
"refresh_token": "refresh_xyz",
518+
"audience": "my_audience"
519+
})
520+
521+
@pytest.mark.asyncio
522+
async def test_get_access_token_mrrt(mocker):
523+
mock_state_store = AsyncMock()
524+
# expired token
525+
mock_state_store.get.return_value = {
526+
"refresh_token": "refresh_xyz",
527+
"token_sets": [
528+
{
529+
"audience": "default",
530+
"access_token": "valid_token_for_other_audience",
531+
"expires_at": int(time.time()) + 500
532+
}
533+
]
534+
}
535+
536+
client = ServerClient(
537+
domain="auth0.local",
538+
client_id="client_id",
539+
client_secret="client_secret",
540+
transaction_store=AsyncMock(),
541+
state_store=mock_state_store,
542+
secret="some-secret"
543+
)
544+
545+
# Patch method that does the refresh call
546+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={
547+
"access_token": "new_token",
548+
"expires_in": 3600
549+
})
550+
551+
token = await client.get_access_token(
552+
audience="some_audience",
553+
scope="foo:bar"
554+
)
555+
556+
assert token == "new_token"
557+
mock_state_store.set.assert_awaited_once()
558+
args, kwargs = mock_state_store.set.call_args
559+
stored_state = args[1]
560+
assert "token_sets" in stored_state
561+
assert len(stored_state["token_sets"]) == 2
562+
get_refresh_token_mock.assert_awaited_with({
563+
"refresh_token": "refresh_xyz",
564+
"audience": "some_audience",
565+
"scope": "foo:bar",
566+
})
567+
568+
@pytest.mark.asyncio
569+
async def test_get_access_token_mrrt_with_auth_params_scope(mocker):
570+
mock_state_store = AsyncMock()
571+
# expired token
572+
mock_state_store.get.return_value = {
573+
"refresh_token": "refresh_xyz",
574+
"token_sets": [
575+
{
576+
"audience": "default",
577+
"access_token": "valid_token_for_other_audience",
578+
"expires_at": int(time.time()) + 500
579+
}
580+
]
581+
}
582+
583+
client = ServerClient(
584+
domain="auth0.local",
585+
client_id="client_id",
586+
client_secret="client_secret",
587+
transaction_store=AsyncMock(),
588+
state_store=mock_state_store,
589+
secret="some-secret",
590+
authorization_params= {
591+
"audience": "default",
592+
"scope": {
593+
"default": "openid profile email foo:bar",
594+
"some_audience": "foo:bar"
595+
}
596+
}
597+
)
598+
599+
# Patch method that does the refresh call
600+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token", return_value={
601+
"access_token": "new_token",
602+
"expires_in": 3600
603+
})
604+
605+
token = await client.get_access_token(
606+
audience="some_audience"
607+
)
608+
609+
assert token == "new_token"
610+
mock_state_store.set.assert_awaited_once()
611+
args, kwargs = mock_state_store.set.call_args
612+
stored_state = args[1]
613+
assert "token_sets" in stored_state
614+
assert len(stored_state["token_sets"]) == 2
615+
get_refresh_token_mock.assert_awaited_with({
616+
"refresh_token": "refresh_xyz",
617+
"audience": "some_audience",
618+
"scope": "foo:bar",
619+
})
620+
621+
@pytest.mark.asyncio
622+
async def test_get_access_token_from_store_with_multilpe_audiences(mocker):
623+
mock_state_store = AsyncMock()
624+
mock_state_store.get.return_value = {
625+
"refresh_token": None,
626+
"token_sets": [
627+
{
628+
"audience": "default",
629+
"access_token": "token_from_store",
630+
"expires_at": int(time.time()) + 500
631+
},
632+
{
633+
"audience": "some_audience",
634+
"access_token": "other_token_from_store",
635+
"scope": "foo:bar",
636+
"expires_at": int(time.time()) + 500
637+
}
638+
]
639+
}
640+
641+
client = ServerClient(
642+
domain="auth0.local",
643+
client_id="client_id",
644+
client_secret="client_secret",
645+
transaction_store=AsyncMock(),
646+
state_store=mock_state_store,
647+
secret="some-secret"
648+
)
649+
650+
get_refresh_token_mock = mocker.patch.object(client, "get_token_by_refresh_token")
651+
652+
token = await client.get_access_token(
653+
audience="some_audience",
654+
scope="foo:bar"
655+
)
656+
657+
assert token == "other_token_from_store"
658+
get_refresh_token_mock.assert_not_awaited()
396659

397660
@pytest.mark.asyncio
398661
async def test_get_access_token_for_connection_cached():

0 commit comments

Comments
 (0)