Skip to content

Commit 1def3c3

Browse files
committed
feat: support custom boto3 client
1 parent 8ea89bb commit 1def3c3

File tree

3 files changed

+36
-15
lines changed

3 files changed

+36
-15
lines changed

pydantic_ssm_settings/settings.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Optional, Tuple, Type
2+
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type
33

44
from pydantic_settings import (
55
BaseSettings,
@@ -12,11 +12,15 @@
1212

1313
from .source import AwsSsmSettingsSource
1414

15+
if TYPE_CHECKING:
16+
from mypy_boto3_ssm.client import SSMClient
17+
1518
logger = logging.getLogger(__name__)
1619

1720

1821
class SsmSettingsConfigDict(SettingsConfigDict):
1922
ssm_prefix: str
23+
ssm_client: Optional["SSMClient"]
2024

2125

2226
class AwsSsmBaseSettings(BaseSettings):
@@ -29,6 +33,7 @@ def __init__(
2933
self,
3034
*args,
3135
_ssm_prefix: Optional[str] = None,
36+
_ssm_client: Optional["SSMClient"] = None,
3237
**kwargs: Any,
3338
) -> None:
3439
"""
@@ -37,9 +42,12 @@ def __init__(
3742
separated by "/". NB:unlike its _env_prefix counterpart, _ssm_prefix
3843
is treated case sensitively regardless of the _case_sensitive
3944
parameter value.
45+
_ssm_client: Optional boto3 SSM client. If not provided, a new client
46+
will be created.
4047
"""
4148
# NOTE: Need a direct access to the attributes dictionary to avoid raising an AttributeError: __pydantic_private__ exception
4249
self.__dict__["__ssm_prefix"] = _ssm_prefix
50+
self.__dict__["__ssm_client"] = _ssm_client
4351
super().__init__(self, *args, **kwargs)
4452

4553
def settings_customise_sources(
@@ -53,6 +61,7 @@ def settings_customise_sources(
5361
ssm_settings = AwsSsmSettingsSource(
5462
settings_cls=settings_cls,
5563
ssm_prefix=self.__dict__["__ssm_prefix"],
64+
ssm_client=self.__dict__["__ssm_client"],
5665
)
5766

5867
return (

pydantic_ssm_settings/source.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,31 @@ def __init__(
3131
settings_cls: type[BaseSettings],
3232
case_sensitive: Optional[bool] = None,
3333
ssm_prefix: Optional[str] = None,
34+
ssm_client: Optional["SSMClient"] = None,
3435
):
35-
# Ideally would retrieve ssm_prefix from self.config
36-
# but need the superclass to be initialized for that
37-
ssm_prefix_ = (
36+
ssm_prefix = (
3837
ssm_prefix
3938
if ssm_prefix is not None
4039
else settings_cls.model_config.get("ssm_prefix", "/")
4140
)
41+
self.ssm_client = (
42+
ssm_client
43+
if ssm_client
44+
else settings_cls.model_config.get("ssm_client", self._build_client())
45+
)
4246
super().__init__(
4347
settings_cls,
4448
case_sensitive=case_sensitive,
45-
env_prefix=ssm_prefix_,
49+
env_prefix=ssm_prefix,
4650
env_nested_delimiter="/", # SSM only accepts / as a delimiter
4751
)
48-
self.ssm_prefix = ssm_prefix_
49-
assert self.ssm_prefix == self.env_prefix
50-
51-
@property
52-
def client(self) -> "SSMClient":
53-
return boto3.client("ssm", config=self.client_config)
52+
assert ssm_prefix == self.env_prefix
5453

55-
@property
56-
def client_config(self) -> Config:
54+
def _build_client(self) -> "SSMClient":
5755
timeout = float(os.environ.get("SSM_TIMEOUT", 0.5))
58-
return Config(connect_timeout=timeout, read_timeout=timeout)
56+
return boto3.client(
57+
"ssm", config=Config(connect_timeout=timeout, read_timeout=timeout)
58+
)
5959

6060
def _load_env_vars(
6161
self,
@@ -68,7 +68,7 @@ def _load_env_vars(
6868

6969
output = {}
7070
try:
71-
paginator = self.client.get_paginator("get_parameters_by_path")
71+
paginator = self.ssm_client.get_paginator("get_parameters_by_path")
7272
response_iterator = paginator.paginate(
7373
Path=self.env_prefix, WithDecryption=True, Recursive=True
7474
)

tests/test_main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import boto3
2+
import moto
13
import pytest
24
from pydantic import BaseModel
35
from pydantic_settings import SettingsConfigDict
@@ -116,3 +118,13 @@ def test_parameters_from_model_config(ssm):
116118
ssm.put_parameter(Name="/asdf/foo", Value="bar", Type="String")
117119
settings = CustomConfigDict()
118120
assert settings.foo == "bar"
121+
122+
123+
@pytest.mark.parametrize("region", ["us-east-1", "us-west-2"])
124+
def test_custom_client(region: str):
125+
with moto.mock_aws():
126+
client = boto3.client("ssm", region_name=region)
127+
client.put_parameter(Name="/foo", Value=region, Type="String")
128+
129+
settings = SimpleSettings(_ssm_client=client)
130+
assert settings.foo == region

0 commit comments

Comments
 (0)