|
6 | 6 | import json |
7 | 7 | from dataclasses import InitVar, dataclass |
8 | 8 | from datetime import datetime |
9 | | -from typing import Any, Mapping, Optional, Union |
| 9 | +from typing import Any, Mapping, Optional, Union, cast |
10 | 10 |
|
11 | 11 | import jwt |
12 | 12 | from cryptography.hazmat.backends import default_backend |
13 | 13 | from cryptography.hazmat.primitives import serialization |
| 14 | +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey |
| 15 | +from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey |
| 16 | +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey |
| 17 | +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey |
14 | 18 | from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes |
15 | 19 |
|
16 | 20 | from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator |
17 | 21 | from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean |
18 | 22 | from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping |
19 | 23 | from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString |
20 | 24 |
|
| 25 | +# Type alias for keys that JWT library accepts |
| 26 | +JwtKeyTypes = Union[ |
| 27 | + RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey, str, bytes |
| 28 | +] |
| 29 | + |
21 | 30 |
|
22 | 31 | class JwtAlgorithm(str): |
23 | 32 | """ |
@@ -158,18 +167,21 @@ def _get_jwt_payload(self) -> dict[str, Any]: |
158 | 167 | payload["nbf"] = nbf |
159 | 168 | return payload |
160 | 169 |
|
161 | | - def _get_secret_key(self) -> PrivateKeyTypes | str | bytes: |
| 170 | + def _get_secret_key(self) -> JwtKeyTypes: |
162 | 171 | """ |
163 | 172 | Returns the secret key used to sign the JWT. |
164 | 173 | """ |
165 | 174 | secret_key: str = self._secret_key.eval(self.config, json_loads=json.loads) |
166 | 175 |
|
167 | 176 | if self._passphrase: |
168 | | - return serialization.load_pem_private_key( |
| 177 | + # Load encrypted private key and cast to JWT-compatible type |
| 178 | + # The JWT algorithms we support (RSA, ECDSA, EdDSA) use compatible key types |
| 179 | + private_key = serialization.load_pem_private_key( |
169 | 180 | secret_key.encode(), |
170 | 181 | password=self._passphrase.eval(self.config, json_loads=json.loads).encode(), |
171 | 182 | backend=default_backend(), |
172 | 183 | ) |
| 184 | + return cast(JwtKeyTypes, private_key) |
173 | 185 | else: |
174 | 186 | return ( |
175 | 187 | base64.b64encode(secret_key.encode()).decode() |
|
0 commit comments