|
| 1 | +import base64 |
1 | 2 | import json
|
| 3 | +import re |
| 4 | +from datetime import datetime, timedelta |
2 | 5 |
|
| 6 | +import aiohttp |
| 7 | +import jwt |
3 | 8 | import pytest
|
4 | 9 | from azure.core.credentials import AzureKeyCredential
|
5 | 10 | from azure.search.documents.aio import SearchClient
|
6 | 11 | from azure.search.documents.indexes.models import SearchField, SearchIndex
|
| 12 | +from cryptography.hazmat.primitives import serialization |
| 13 | +from cryptography.hazmat.primitives.asymmetric import rsa |
7 | 14 |
|
8 | 15 | from core.authentication import AuthenticationHelper, AuthError
|
9 | 16 |
|
10 |
| -from .mocks import MockAsyncPageIterator |
| 17 | +from .mocks import MockAsyncPageIterator, MockResponse |
11 | 18 |
|
12 | 19 | MockSearchIndex = SearchIndex(
|
13 | 20 | name="test",
|
@@ -40,6 +47,36 @@ def create_search_client():
|
40 | 47 | return SearchClient(endpoint="", index_name="", credential=AzureKeyCredential(""))
|
41 | 48 |
|
42 | 49 |
|
| 50 | +def create_mock_jwt(kid="mock_kid", oid="OID_X"): |
| 51 | + # Create a payload with necessary claims |
| 52 | + payload = { |
| 53 | + "iss": "https://login.microsoftonline.com/TENANT_ID/v2.0", |
| 54 | + "sub": "AaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaA", |
| 55 | + "aud": "SERVER_APP", |
| 56 | + "exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp()), |
| 57 | + "iat": int(datetime.utcnow().timestamp()), |
| 58 | + "nbf": int(datetime.utcnow().timestamp()), |
| 59 | + "name": "John Doe", |
| 60 | + "oid": oid, |
| 61 | + "preferred_username": "[email protected]", |
| 62 | + "rh": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA.", |
| 63 | + "tid": "22222222-2222-2222-2222-222222222222", |
| 64 | + "uti": "AbCdEfGhIjKlMnOp-ABCDEFG", |
| 65 | + "ver": "2.0", |
| 66 | + } |
| 67 | + |
| 68 | + # Create a header |
| 69 | + header = {"kid": kid, "alg": "RS256", "typ": "JWT"} |
| 70 | + |
| 71 | + # Create a mock private key (for signing) |
| 72 | + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 73 | + |
| 74 | + # Create the JWT |
| 75 | + token = jwt.encode(payload, private_key, algorithm="RS256", headers=header) |
| 76 | + |
| 77 | + return token, private_key.public_key(), payload |
| 78 | + |
| 79 | + |
43 | 80 | @pytest.mark.asyncio
|
44 | 81 | async def test_get_auth_claims_success(mock_confidential_client_success, mock_validate_token_success):
|
45 | 82 | helper = create_authentication_helper()
|
@@ -479,3 +516,136 @@ async def mock_search(self, *args, **kwargs):
|
479 | 516 | )
|
480 | 517 | assert filter is None
|
481 | 518 | assert called_search is False
|
| 519 | + |
| 520 | + |
| 521 | +@pytest.mark.asyncio |
| 522 | +async def test_create_pem_format(mock_confidential_client_success, mock_validate_token_success): |
| 523 | + helper = create_authentication_helper() |
| 524 | + mock_token, public_key, payload = create_mock_jwt(oid="OID_X") |
| 525 | + _, other_public_key, _ = create_mock_jwt(oid="OID_Y") |
| 526 | + mock_jwks = { |
| 527 | + "keys": [ |
| 528 | + # Include a key with a different KID to ensure the correct key is selected |
| 529 | + { |
| 530 | + "kty": "RSA", |
| 531 | + "kid": "other_mock_kid", |
| 532 | + "use": "sig", |
| 533 | + "n": base64.urlsafe_b64encode( |
| 534 | + other_public_key.public_numbers().n.to_bytes( |
| 535 | + (other_public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big" |
| 536 | + ) |
| 537 | + ) |
| 538 | + .decode("utf-8") |
| 539 | + .rstrip("="), |
| 540 | + "e": base64.urlsafe_b64encode( |
| 541 | + other_public_key.public_numbers().e.to_bytes( |
| 542 | + (other_public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big" |
| 543 | + ) |
| 544 | + ) |
| 545 | + .decode("utf-8") |
| 546 | + .rstrip("="), |
| 547 | + }, |
| 548 | + { |
| 549 | + "kty": "RSA", |
| 550 | + "kid": "mock_kid", |
| 551 | + "use": "sig", |
| 552 | + "n": base64.urlsafe_b64encode( |
| 553 | + public_key.public_numbers().n.to_bytes( |
| 554 | + (public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big" |
| 555 | + ) |
| 556 | + ) |
| 557 | + .decode("utf-8") |
| 558 | + .rstrip("="), |
| 559 | + "e": base64.urlsafe_b64encode( |
| 560 | + public_key.public_numbers().e.to_bytes( |
| 561 | + (public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big" |
| 562 | + ) |
| 563 | + ) |
| 564 | + .decode("utf-8") |
| 565 | + .rstrip("="), |
| 566 | + }, |
| 567 | + ] |
| 568 | + } |
| 569 | + |
| 570 | + pem_key = await helper.create_pem_format(mock_jwks, mock_token) |
| 571 | + |
| 572 | + # Assert that the result is bytes |
| 573 | + assert isinstance(pem_key, bytes), "create_pem_format should return bytes" |
| 574 | + |
| 575 | + # Convert bytes to string for regex matching |
| 576 | + pem_str = pem_key.decode("utf-8") |
| 577 | + |
| 578 | + # Assert that the key starts and ends with the correct markers |
| 579 | + assert pem_str.startswith("-----BEGIN PUBLIC KEY-----"), "PEM key should start with the correct marker" |
| 580 | + assert pem_str.endswith("-----END PUBLIC KEY-----\n"), "PEM key should end with the correct marker" |
| 581 | + |
| 582 | + # Assert that the format matches the structure of a PEM key |
| 583 | + pem_regex = r"^-----BEGIN PUBLIC KEY-----\n([A-Za-z0-9+/\n]+={0,2})\n-----END PUBLIC KEY-----\n$" |
| 584 | + assert re.match(pem_regex, pem_str), "PEM key format is incorrect" |
| 585 | + |
| 586 | + # Verify that the key can be used to decode the token |
| 587 | + try: |
| 588 | + decoded = jwt.decode( |
| 589 | + mock_token, key=pem_key, algorithms=["RS256"], audience=payload["aud"], issuer=payload["iss"] |
| 590 | + ) |
| 591 | + assert decoded["oid"] == payload["oid"], "Decoded token should contain correct OID" |
| 592 | + except Exception as e: |
| 593 | + pytest.fail(f"jwt.decode raised an unexpected exception: {str(e)}") |
| 594 | + |
| 595 | + # Try to load the key using cryptography library to ensure it's a valid PEM format |
| 596 | + try: |
| 597 | + loaded_public_key = serialization.load_pem_public_key(pem_key) |
| 598 | + assert isinstance(loaded_public_key, rsa.RSAPublicKey), "Loaded key should be an RSA public key" |
| 599 | + except Exception as e: |
| 600 | + pytest.fail(f"Failed to load PEM key: {str(e)}") |
| 601 | + |
| 602 | + |
| 603 | +@pytest.mark.asyncio |
| 604 | +async def test_validate_access_token(monkeypatch, mock_confidential_client_success): |
| 605 | + mock_token, public_key, payload = create_mock_jwt(oid="OID_X") |
| 606 | + |
| 607 | + def mock_get(*args, **kwargs): |
| 608 | + return MockResponse( |
| 609 | + status=200, |
| 610 | + text=json.dumps( |
| 611 | + { |
| 612 | + "keys": [ |
| 613 | + { |
| 614 | + "kty": "RSA", |
| 615 | + "use": "sig", |
| 616 | + "kid": "23nt", |
| 617 | + "x5t": "23nt", |
| 618 | + "n": "hu2SJ", |
| 619 | + "e": "AQAB", |
| 620 | + "x5c": ["MIIC/jCC"], |
| 621 | + "issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0", |
| 622 | + }, |
| 623 | + { |
| 624 | + "kty": "RSA", |
| 625 | + "use": "sig", |
| 626 | + "kid": "MGLq", |
| 627 | + "x5t": "MGLq", |
| 628 | + "n": "yfNcG8", |
| 629 | + "e": "AQAB", |
| 630 | + "x5c": ["MIIC/jCC"], |
| 631 | + "issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0", |
| 632 | + }, |
| 633 | + ] |
| 634 | + } |
| 635 | + ), |
| 636 | + ) |
| 637 | + |
| 638 | + monkeypatch.setattr(aiohttp.ClientSession, "get", mock_get) |
| 639 | + |
| 640 | + def mock_decode(*args, **kwargs): |
| 641 | + return payload |
| 642 | + |
| 643 | + monkeypatch.setattr(jwt, "decode", mock_decode) |
| 644 | + |
| 645 | + async def mock_create_pem_format(*args, **kwargs): |
| 646 | + return public_key |
| 647 | + |
| 648 | + monkeypatch.setattr(AuthenticationHelper, "create_pem_format", mock_create_pem_format) |
| 649 | + |
| 650 | + helper = create_authentication_helper() |
| 651 | + await helper.validate_access_token(mock_token) |
0 commit comments