Skip to content

Commit f8fb6b0

Browse files
committed
allow configuring the signed JWT token to be injected as a request_option for JwtAuthenticator
1 parent 25132b6 commit f8fb6b0

File tree

6 files changed

+191
-34
lines changed

6 files changed

+191
-34
lines changed

airbyte_cdk/sources/declarative/auth/jwt.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,23 @@
66
import json
77
from dataclasses import InitVar, dataclass
88
from datetime import datetime
9-
from typing import Any, Mapping, Optional, Union, cast
9+
from typing import Any, Mapping, MutableMapping, Optional, Union, cast
1010

1111
import jwt
1212
from cryptography.hazmat.primitives import serialization
1313
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
1414
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey
1515
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
1616
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
17-
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
1817

1918
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
2019
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
2120
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
2221
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
22+
from airbyte_cdk.sources.declarative.requesters.request_option import (
23+
RequestOption,
24+
RequestOptionType,
25+
)
2326

2427
# Type alias for keys that JWT library accepts
2528
JwtKeyTypes = Union[
@@ -86,6 +89,7 @@ class JwtAuthenticator(DeclarativeAuthenticator):
8689
additional_jwt_headers: Optional[Mapping[str, Any]] = None
8790
additional_jwt_payload: Optional[Mapping[str, Any]] = None
8891
passphrase: Optional[Union[InterpolatedString, str]] = None
92+
request_option: Optional[RequestOption] = None
8993

9094
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
9195
self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters)
@@ -121,6 +125,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
121125
else None
122126
)
123127

128+
# When we first implemented the JWT authenticator, we assumed that the signed token was always supposed
129+
# to be loaded into the request headers under the `Authorization` key. This is not always the case, but
130+
# this default option allows for backwards compatibility to be retained for existing connectors
131+
self._request_option = self.request_option or RequestOption(
132+
inject_into=RequestOptionType.header, field_name="Authorization", parameters=parameters
133+
)
134+
124135
def _get_jwt_headers(self) -> dict[str, Any]:
125136
"""
126137
Builds and returns the headers used when signing the JWT.
@@ -213,7 +224,8 @@ def _get_header_prefix(self) -> Union[str, None]:
213224

214225
@property
215226
def auth_header(self) -> str:
216-
return "Authorization"
227+
options = self._get_request_options(RequestOptionType.header)
228+
return next(iter(options.keys()), "")
217229

218230
@property
219231
def token(self) -> str:
@@ -222,3 +234,18 @@ def token(self) -> str:
222234
if self._get_header_prefix()
223235
else self._get_signed_token()
224236
)
237+
238+
def get_request_params(self) -> Mapping[str, Any]:
239+
return self._get_request_options(RequestOptionType.request_parameter)
240+
241+
def get_request_body_data(self) -> Union[Mapping[str, Any], str]:
242+
return self._get_request_options(RequestOptionType.body_data)
243+
244+
def get_request_body_json(self) -> Mapping[str, Any]:
245+
return self._get_request_options(RequestOptionType.body_json)
246+
247+
def _get_request_options(self, option_type: RequestOptionType) -> Mapping[str, Any]:
248+
options: MutableMapping[str, Any] = {}
249+
if self._request_option.inject_into == option_type:
250+
self._request_option.inject_into_request(options, self.token, self.config)
251+
return options

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,10 @@ definitions:
12761276
type: string
12771277
examples:
12781278
- "{{ config['passphrase'] }}"
1279+
request_option:
1280+
title: Request Option
1281+
description: A request option describing where the signed JWT token that is generated should be injected into the outbound API request.
1282+
"$ref": "#/definitions/RequestOption"
12791283
$parameters:
12801284
type: object
12811285
additionalProperties: true

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,35 @@ class Algorithm(Enum):
350350
EdDSA = "EdDSA"
351351

352352

353+
class InjectInto(Enum):
354+
request_parameter = "request_parameter"
355+
header = "header"
356+
body_data = "body_data"
357+
body_json = "body_json"
358+
359+
360+
class RequestOption(BaseModel):
361+
type: Literal["RequestOption"]
362+
inject_into: InjectInto = Field(
363+
...,
364+
description="Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated.",
365+
examples=["request_parameter", "header", "body_data", "body_json"],
366+
title="Inject Into",
367+
)
368+
field_name: Optional[str] = Field(
369+
None,
370+
description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.",
371+
examples=["segment_id"],
372+
title="Field Name",
373+
)
374+
field_path: Optional[List[str]] = Field(
375+
None,
376+
description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)",
377+
examples=[["data", "viewer", "id"]],
378+
title="Field Path",
379+
)
380+
381+
353382
class JwtHeaders(BaseModel):
354383
class Config:
355384
extra = Extra.forbid
@@ -454,6 +483,11 @@ class JwtAuthenticator(BaseModel):
454483
examples=["{{ config['passphrase'] }}"],
455484
title="Passphrase",
456485
)
486+
request_option: Optional[RequestOption] = Field(
487+
None,
488+
description="A request option describing where the generated JWT token should be injected into the outbound API request.",
489+
title="Request Option",
490+
)
457491
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
458492

459493

@@ -1294,35 +1328,6 @@ class RequestPath(BaseModel):
12941328
type: Literal["RequestPath"]
12951329

12961330

1297-
class InjectInto(Enum):
1298-
request_parameter = "request_parameter"
1299-
header = "header"
1300-
body_data = "body_data"
1301-
body_json = "body_json"
1302-
1303-
1304-
class RequestOption(BaseModel):
1305-
type: Literal["RequestOption"]
1306-
inject_into: InjectInto = Field(
1307-
...,
1308-
description="Configures where the descriptor should be set on the HTTP requests. Note that request parameters that are already encoded in the URL path will not be duplicated.",
1309-
examples=["request_parameter", "header", "body_data", "body_json"],
1310-
title="Inject Into",
1311-
)
1312-
field_name: Optional[str] = Field(
1313-
None,
1314-
description="Configures which key should be used in the location that the descriptor is being injected into. We hope to eventually deprecate this field in favor of `field_path` for all request_options, but must currently maintain it for backwards compatibility in the Builder.",
1315-
examples=["segment_id"],
1316-
title="Field Name",
1317-
)
1318-
field_path: Optional[List[str]] = Field(
1319-
None,
1320-
description="Configures a path to be used for nested structures in JSON body requests (e.g. GraphQL queries)",
1321-
examples=[["data", "viewer", "id"]],
1322-
title="Field Path",
1323-
)
1324-
1325-
13261331
class Schemas(BaseModel):
13271332
pass
13281333

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2683,12 +2683,16 @@ def create_json_file_schema_loader(
26832683
file_path=model.file_path or "", config=config, parameters=model.parameters or {}
26842684
)
26852685

2686-
@staticmethod
26872686
def create_jwt_authenticator(
2688-
model: JwtAuthenticatorModel, config: Config, **kwargs: Any
2687+
self, model: JwtAuthenticatorModel, config: Config, **kwargs: Any
26892688
) -> JwtAuthenticator:
26902689
jwt_headers = model.jwt_headers or JwtHeadersModel(kid=None, typ="JWT", cty=None)
26912690
jwt_payload = model.jwt_payload or JwtPayloadModel(iss=None, sub=None, aud=None)
2691+
request_option = (
2692+
self._create_component_from_model(model.request_option, config)
2693+
if model.request_option
2694+
else None
2695+
)
26922696
return JwtAuthenticator(
26932697
config=config,
26942698
parameters=model.parameters or {},
@@ -2706,6 +2710,7 @@ def create_jwt_authenticator(
27062710
additional_jwt_headers=model.additional_jwt_headers,
27072711
additional_jwt_payload=model.additional_jwt_payload,
27082712
passphrase=model.passphrase,
2713+
request_option=request_option,
27092714
)
27102715

27112716
def create_list_partition_router(

unit_tests/sources/declarative/auth/test_jwt.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from cryptography.hazmat.primitives.asymmetric import rsa
1414

1515
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
16+
from airbyte_cdk.sources.declarative.requesters.request_option import (
17+
RequestOption,
18+
RequestOptionType,
19+
)
1620

1721
LOGGER = logging.getLogger(__name__)
1822

@@ -285,3 +289,106 @@ def test_get_signed_token_with_passphrase_protected_key(self):
285289
assert decoded_payload["iss"] == "test_issuer"
286290
assert "iat" in decoded_payload
287291
assert "exp" in decoded_payload
292+
293+
@pytest.mark.parametrize(
294+
"request_option, expected_request_key",
295+
[
296+
pytest.param(
297+
RequestOption(
298+
inject_into=RequestOptionType.request_parameter,
299+
field_name="custom_parameter",
300+
parameters={},
301+
),
302+
"custom_parameter",
303+
id="test_get_request_headers",
304+
),
305+
pytest.param(
306+
RequestOption(
307+
inject_into=RequestOptionType.body_data, field_name="custom_body", parameters={}
308+
),
309+
"custom_body",
310+
id="test_get_request_headers",
311+
),
312+
pytest.param(
313+
RequestOption(
314+
inject_into=RequestOptionType.body_json, field_name="custom_json", parameters={}
315+
),
316+
"custom_json",
317+
id="test_get_request_headers",
318+
),
319+
],
320+
)
321+
def test_get_request_options(self, request_option, expected_request_key):
322+
authenticator = JwtAuthenticator(
323+
config={},
324+
parameters={},
325+
algorithm="HS256",
326+
secret_key="test_key",
327+
token_duration=1000,
328+
iss="test_iss",
329+
sub="test_sub",
330+
aud="test_aud",
331+
additional_jwt_payload={"kid": "test_kid"},
332+
request_option=request_option,
333+
)
334+
335+
expected_request_options = {
336+
expected_request_key: jwt.encode(
337+
payload=authenticator._get_jwt_payload(),
338+
key=authenticator._get_secret_key(),
339+
algorithm=authenticator._algorithm,
340+
headers=authenticator._get_jwt_headers(),
341+
)
342+
}
343+
344+
match request_option.inject_into:
345+
case RequestOptionType.request_parameter:
346+
actual_request_options = authenticator.get_request_params()
347+
case RequestOptionType.body_data:
348+
actual_request_options = authenticator.get_request_body_data()
349+
case RequestOptionType.body_json:
350+
actual_request_options = authenticator.get_request_body_json()
351+
case _:
352+
actual_request_options = None
353+
354+
assert actual_request_options == expected_request_options
355+
356+
@pytest.mark.parametrize(
357+
"request_option, expected_header_key",
358+
[
359+
pytest.param(
360+
RequestOption(
361+
inject_into=RequestOptionType.header,
362+
field_name="custom_authorization",
363+
parameters={},
364+
),
365+
"custom_authorization",
366+
id="test_get_request_headers",
367+
),
368+
pytest.param(None, "Authorization", id="test_with_default_authorization_header"),
369+
],
370+
)
371+
def test_get_request_headers(self, request_option, expected_header_key):
372+
authenticator = JwtAuthenticator(
373+
config={},
374+
parameters={},
375+
algorithm="HS256",
376+
secret_key="test_key",
377+
token_duration=1000,
378+
iss="test_iss",
379+
sub="test_sub",
380+
aud="test_aud",
381+
additional_jwt_payload={"kid": "test_kid"},
382+
request_option=request_option,
383+
)
384+
385+
expected_headers = {
386+
expected_header_key: jwt.encode(
387+
payload=authenticator._get_jwt_payload(),
388+
key=authenticator._get_secret_key(),
389+
algorithm=authenticator._algorithm,
390+
headers=authenticator._get_jwt_headers(),
391+
)
392+
}
393+
394+
assert authenticator.get_auth_header() == expected_headers

unit_tests/sources/declarative/parsers/test_model_to_component_factory.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3034,6 +3034,10 @@ def test_create_custom_retriever():
30343034
aud: "test aud"
30353035
additional_jwt_payload:
30363036
test: "test custom payload"
3037+
request_option:
3038+
type: RequestOption
3039+
inject_into: body_json
3040+
field_name: authorization
30373041
""",
30383042
{
30393043
"secret_key": "secret_key",
@@ -3141,6 +3145,11 @@ def test_create_jwt_authenticator(config, manifest, expected):
31413145
)
31423146
assert authenticator._get_jwt_payload() == jwt_payload
31433147

3148+
if authenticator_manifest.get("request_option"):
3149+
assert authenticator._request_option.inject_into.value == authenticator_manifest.get(
3150+
"request_option", {}
3151+
).get("inject_into")
3152+
31443153

31453154
def test_use_request_options_provider_for_datetime_based_cursor():
31463155
config = {

0 commit comments

Comments
 (0)