Skip to content

Commit c62cdc3

Browse files
committed
Handle redirect_uri properly
1 parent c1d0167 commit c62cdc3

File tree

3 files changed

+104
-17
lines changed

3 files changed

+104
-17
lines changed

src/auth0_server_python/auth_server/server_client.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,7 +1289,6 @@ async def start_connect_account(
12891289
options: ConnectAccountOptions,
12901290
store_options: dict = None
12911291
) -> str:
1292-
12931292
# Get effective authorization params (merge defaults with provided ones)
12941293
auth_params = dict(self._default_authorization_params)
12951294
if options.authorization_params:
@@ -1298,24 +1297,22 @@ async def start_connect_account(
12981297
) if k not in INTERNAL_AUTHORIZE_PARAMS}
12991298
)
13001299

1300+
# Use the default redirect_uri if none is specified
1301+
redirect_uri = options.redirect_uri or self._redirect_uri
13011302
# Ensure we have a redirect_uri
1302-
if "redirect_uri" not in auth_params and not self._redirect_uri:
1303+
if not redirect_uri:
13031304
raise MissingRequiredArgumentError("redirect_uri")
13041305

1305-
# Use the default redirect_uri if none is specified
1306-
if "redirect_uri" not in auth_params and self._redirect_uri:
1307-
auth_params["redirect_uri"] = self._redirect_uri
1308-
13091306
# Generate PKCE code verifier and challenge
13101307
code_verifier = PKCE.generate_code_verifier()
13111308
code_challenge = PKCE.generate_code_challenge(code_verifier)
13121309

13131310
# State parameter to prevent CSRF
13141311
state = PKCE.generate_random_string(32)
1315-
1312+
13161313
connect_request = ConnectAccountRequest(
13171314
connection=options.connection,
1318-
redirect_uri = options.redirect_uri or auth_params["redirect_uri"],
1315+
redirect_uri = redirect_uri,
13191316
code_challenge=code_challenge,
13201317
code_challenge_method="S256",
13211318
state=state,
@@ -1335,7 +1332,8 @@ async def start_connect_account(
13351332
transaction_data = TransactionData(
13361333
code_verifier=code_verifier,
13371334
app_state=state,
1338-
auth_session = connect_response.auth_session
1335+
auth_session=connect_response.auth_session,
1336+
redirect_uri=redirect_uri
13391337
)
13401338

13411339
# Store the transaction data
@@ -1359,18 +1357,17 @@ async def complete_connect_account(
13591357

13601358
if not transaction_data:
13611359
raise MissingTransactionError()
1362-
1363-
# TODO //do I need to check error in redirect??
1364-
# TODO //handle no redirect uri??
1360+
13651361
access_token = await self.get_access_token(
13661362
audience=self._my_account_client.audienceIdentifier,
13671363
scope="create:me:connected_accounts",
13681364
store_options=store_options
13691365
)
1366+
13701367
request = CompleteConnectAccountRequest(
13711368
auth_session=transaction_data.auth_session,
13721369
connect_code=connect_code,
1373-
redirect_uri=self._redirect_uri,
1370+
redirect_uri=transaction_data.redirect_uri,
13741371
code_verifier=transaction_data.code_verifier
13751372
)
13761373

src/auth0_server_python/auth_types/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TransactionData(BaseModel):
8888
code_verifier: str
8989
app_state: Optional[Any] = None
9090
auth_session: Optional[str] = None
91+
redirect_uri: Optional[str] = None
9192

9293
class Config:
9394
extra = "allow" # Allow additional fields not defined in the model

src/auth0_server_python/tests/test_server_client.py

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,14 +1268,76 @@ async def test_start_connect_account_calls_connect_and_builds_url(mocker):
12681268
mock_transaction_store = AsyncMock()
12691269
mock_state_store = AsyncMock()
12701270

1271+
client = ServerClient(
1272+
domain="auth0.local",
1273+
client_id="<client_id>",
1274+
client_secret="<client_secret>",
1275+
state_store=mock_state_store,
1276+
transaction_store=mock_transaction_store,
1277+
secret="some-secret"
1278+
)
1279+
1280+
mocker.patch.object(client, "get_access_token", AsyncMock(return_value="<access_token>"))
1281+
mock_my_account_client = AsyncMock(MyAccountClient)
1282+
mocker.patch.object(client, "_my_account_client", mock_my_account_client)
1283+
mock_my_account_client.connect_account.return_value = ConnectAccountResponse(
1284+
auth_session="<auth_session>",
1285+
connect_uri="http://auth0.local/connected_accounts/connect",
1286+
connect_params=ConnectParams(
1287+
ticket="ticket123",
1288+
),
1289+
expires_in=300
1290+
)
1291+
1292+
mocker.patch.object(PKCE, "generate_random_string", return_value="<state>")
1293+
mocker.patch.object(PKCE, "generate_code_verifier", return_value="<code_verifier>")
1294+
mocker.patch.object(PKCE, "generate_code_challenge", return_value="<code_challenge>")
1295+
1296+
# Act
1297+
url = await client.start_connect_account(
1298+
options=ConnectAccountOptions(
1299+
connection="<connection>",
1300+
redirect_uri="/test_redirect_uri"
1301+
)
1302+
)
1303+
1304+
# Assert
1305+
assert url == "http://auth0.local/connected_accounts/connect?ticket=ticket123"
1306+
mock_my_account_client.connect_account.assert_awaited_with(
1307+
access_token="<access_token>",
1308+
request=ConnectAccountRequest(
1309+
connection="<connection>",
1310+
redirect_uri="/test_redirect_uri",
1311+
code_challenge_method="S256",
1312+
code_challenge="<code_challenge>",
1313+
state= "<state>"
1314+
)
1315+
)
1316+
mock_transaction_store.set.assert_awaited_with(
1317+
"_a0_tx:<state>",
1318+
TransactionData(
1319+
code_verifier="<code_verifier>",
1320+
app_state="<state>",
1321+
auth_session="<auth_session>",
1322+
redirect_uri="/test_redirect_uri"
1323+
),
1324+
options=ANY
1325+
)
1326+
1327+
@pytest.mark.asyncio
1328+
async def test_start_connect_account_default_redirect_uri(mocker):
1329+
# Setup
1330+
mock_transaction_store = AsyncMock()
1331+
mock_state_store = AsyncMock()
1332+
12711333
client = ServerClient(
12721334
domain="auth0.local",
12731335
client_id="<client_id>",
12741336
client_secret="<client_secret>",
12751337
state_store=mock_state_store,
12761338
transaction_store=mock_transaction_store,
12771339
secret="some-secret",
1278-
redirect_uri="/test_redirect_uri"
1340+
redirect_uri="/default_redirect_uri"
12791341
)
12801342

12811343
mocker.patch.object(client, "get_access_token", AsyncMock(return_value="<access_token>"))
@@ -1298,7 +1360,7 @@ async def test_start_connect_account_calls_connect_and_builds_url(mocker):
12981360
url = await client.start_connect_account(
12991361
options=ConnectAccountOptions(
13001362
connection="<connection>"
1301-
),
1363+
)
13021364
)
13031365

13041366
# Assert
@@ -1307,7 +1369,7 @@ async def test_start_connect_account_calls_connect_and_builds_url(mocker):
13071369
access_token="<access_token>",
13081370
request=ConnectAccountRequest(
13091371
connection="<connection>",
1310-
redirect_uri="/test_redirect_uri",
1372+
redirect_uri="/default_redirect_uri",
13111373
code_challenge_method="S256",
13121374
code_challenge="<code_challenge>",
13131375
state= "<state>"
@@ -1319,10 +1381,37 @@ async def test_start_connect_account_calls_connect_and_builds_url(mocker):
13191381
code_verifier="<code_verifier>",
13201382
app_state="<state>",
13211383
auth_session="<auth_session>",
1384+
redirect_uri="/default_redirect_uri"
13221385
),
13231386
options=ANY
13241387
)
13251388

1389+
@pytest.mark.asyncio
1390+
async def test_start_connect_account_no_redirect_uri(mocker):
1391+
# Setup
1392+
mock_transaction_store = AsyncMock()
1393+
mock_state_store = AsyncMock()
1394+
1395+
client = ServerClient(
1396+
domain="auth0.local",
1397+
client_id="<client_id>",
1398+
client_secret="<client_secret>",
1399+
state_store=mock_state_store,
1400+
transaction_store=mock_transaction_store,
1401+
secret="some-secret"
1402+
)
1403+
1404+
# Act
1405+
with pytest.raises(MissingRequiredArgumentError) as exc:
1406+
await client.start_connect_account(
1407+
options=ConnectAccountOptions(
1408+
connection="<connection>"
1409+
)
1410+
)
1411+
1412+
# Assert
1413+
assert "redirect_uri" in str(exc.value)
1414+
13261415
@pytest.mark.asyncio
13271416
async def test_complete_connect_account_calls_complete(mocker):
13281417
# Setup
@@ -1347,6 +1436,7 @@ async def test_complete_connect_account_calls_complete(mocker):
13471436
code_verifier="<code_verifier>",
13481437
app_state="<state>",
13491438
auth_session="<auth_session>",
1439+
redirect_uri="/test_redirect_uri"
13501440
)
13511441

13521442
# Act
@@ -1366,7 +1456,6 @@ async def test_complete_connect_account_calls_complete(mocker):
13661456
)
13671457
)
13681458

1369-
13701459
@pytest.mark.asyncio
13711460
async def test_complete_connect_account_no_transactions(mocker):
13721461
# Setup

0 commit comments

Comments
 (0)