1- import os
1+ from __future__ import annotations as _annotations
2+
23import logging
4+ import os
35from pathlib import Path
4- from typing import TYPE_CHECKING , Any , Dict , Mapping , Optional , Tuple
6+ from typing import TYPE_CHECKING , Any
57
6- from botocore .exceptions import ClientError
7- from botocore .client import Config
88import boto3
9-
10- from pydantic import BaseSettings
11- from pydantic .typing import StrPath , get_origin , is_union
12- from pydantic .utils import deep_update
13- from pydantic .fields import ModelField
9+ from botocore .client import Config
10+ from botocore .exceptions import ClientError
11+ from pydantic import BaseModel
12+ from pydantic ._internal ._utils import lenient_issubclass
13+ from pydantic .fields import FieldInfo
14+ from pydantic_settings import BaseSettings
15+ from pydantic_settings .sources import (
16+ EnvSettingsSource ,
17+ )
1418
1519if TYPE_CHECKING :
1620 from mypy_boto3_ssm .client import SSMClient
@@ -23,16 +27,28 @@ class SettingsError(ValueError):
2327 pass
2428
2529
26- class AwsSsmSettingsSource :
27- __slots__ = ("ssm_prefix" , "env_nested_delimiter" )
28-
30+ class AwsSsmSettingsSource (EnvSettingsSource ):
2931 def __init__ (
3032 self ,
31- ssm_prefix : Optional [StrPath ],
32- env_nested_delimiter : Optional [str ] = None ,
33+ settings_cls : type [BaseSettings ],
34+ case_sensitive : bool = None ,
35+ ssm_prefix : str = None ,
3336 ):
34- self .ssm_prefix : Optional [StrPath ] = ssm_prefix
35- self .env_nested_delimiter : Optional [str ] = env_nested_delimiter
37+ # Ideally would retrieve ssm_prefix from self.config
38+ # but need the superclass to be initialized for that
39+ ssm_prefix_ = (
40+ ssm_prefix
41+ if ssm_prefix is not None
42+ else settings_cls .model_config .get ("ssm_prefix" , "/" )
43+ )
44+ super ().__init__ (
45+ settings_cls ,
46+ case_sensitive = case_sensitive ,
47+ env_prefix = ssm_prefix_ ,
48+ env_nested_delimiter = "/" , # SSM only accepts / as a delimiter
49+ )
50+ self .ssm_prefix = ssm_prefix_
51+ assert self .ssm_prefix == self .env_prefix
3652
3753 @property
3854 def client (self ) -> "SSMClient" :
@@ -43,124 +59,103 @@ def client_config(self) -> Config:
4359 timeout = float (os .environ .get ("SSM_TIMEOUT" , 0.5 ))
4460 return Config (connect_timeout = timeout , read_timeout = timeout )
4561
46- def load_from_ssm (self , secrets_path : Path , case_sensitive : bool ):
47-
48- if not secrets_path .is_absolute ():
62+ def _load_env_vars (
63+ self ,
64+ ):
65+ """
66+ Access env_prefix instead of ssm_prefix
67+ """
68+ if not Path (self .env_prefix ).is_absolute ():
4969 raise ValueError ("SSM prefix must be absolute path" )
5070
51- logger .debug (f"Building SSM settings with prefix of { secrets_path = } " )
71+ logger .debug (f"Building SSM settings with prefix of { self . env_prefix = } " )
5272
5373 output = {}
5474 try :
5575 paginator = self .client .get_paginator ("get_parameters_by_path" )
5676 response_iterator = paginator .paginate (
57- Path = str ( secrets_path ) , WithDecryption = True
77+ Path = self . env_prefix , WithDecryption = True , Recursive = True
5878 )
5979
6080 for page in response_iterator :
6181 for parameter in page ["Parameters" ]:
62- key = Path (parameter ["Name" ]).relative_to (secrets_path ).as_posix ()
63- output [key if case_sensitive else key .lower ()] = parameter ["Value" ]
82+ key = (
83+ Path (parameter ["Name" ]).relative_to (self .env_prefix ).as_posix ()
84+ )
85+ output [
86+ self .env_prefix + key
87+ if self .case_sensitive
88+ else self .env_prefix .lower () + key .lower ()
89+ ] = parameter ["Value" ]
6490
6591 except ClientError :
66- logger .exception ("Failed to get parameters from %s" , secrets_path )
92+ logger .exception ("Failed to get parameters from %s" , self . env_prefix )
6793
6894 return output
6995
70- def __call__ (self , settings : BaseSettings ) -> Dict [str , Any ]:
71- """
72- Returns SSM values for all settings.
73- """
74- d : Dict [str , Optional [Any ]] = {}
75-
76- if self .ssm_prefix is None :
77- return d
78-
79- ssm_values = self .load_from_ssm (
80- secrets_path = Path (self .ssm_prefix ),
81- case_sensitive = settings .__config__ .case_sensitive ,
82- )
96+ def __repr__ (self ) -> str :
97+ return f"AwsSsmSettingsSource(ssm_prefix={ self .env_prefix !r} )"
8398
84- # The following was lifted from https://github.com/samuelcolvin/pydantic/blob/a21f0763ee877f0c86f254a5d60f70b1002faa68/pydantic/env_settings.py#L165-L237 # noqa
85- for field in settings .__fields__ .values ():
86- env_val : Optional [str ] = None
87- for env_name in field .field_info .extra ["env_names" ]:
88- env_val = ssm_values .get (env_name )
89- if env_val is not None :
90- break
91-
92- is_complex , allow_json_failure = self ._field_is_complex (field )
93- if is_complex :
94- if env_val is None :
95- # field is complex but no value found so far, try explode_env_vars
96- env_val_built = self ._explode_ssm_values (field , ssm_values )
97- if env_val_built :
98- d [field .alias ] = env_val_built
99- else :
100- # field is complex and there's a value, decode that as JSON, then
101- # add explode_env_vars
102- try :
103- env_val = settings .__config__ .json_loads (env_val )
104- except ValueError as e :
105- if not allow_json_failure :
106- raise SettingsError (
107- f'error parsing JSON for "{ env_name } "'
108- ) from e
109-
110- if isinstance (env_val , dict ):
111- d [field .alias ] = deep_update (
112- env_val , self ._explode_ssm_values (field , ssm_values )
113- )
114- else :
115- d [field .alias ] = env_val
116- elif env_val is not None :
117- # simplest case, field is not complex, we only need to add the
118- # value if it was found
119- d [field .alias ] = env_val
120-
121- return d
122-
123- def _field_is_complex (self , field : ModelField ) -> Tuple [bool , bool ]:
124- """
125- Find out if a field is complex, and if so whether JSON errors should be ignored
99+ def get_field_value (
100+ self , field : FieldInfo , field_name : str
101+ ) -> tuple [Any , str , bool ]:
126102 """
127- if field .is_complex ():
128- allow_json_failure = False
129- elif (
130- is_union (get_origin (field .type_ ))
131- and field .sub_fields
132- and any (f .is_complex () for f in field .sub_fields )
133- ):
134- allow_json_failure = True
135- else :
136- return False , False
103+ Gets the value for field from environment variables and a flag to
104+ determine whether value is complex.
137105
138- return True , allow_json_failure
106+ Args:
107+ field: The field.
108+ field_name: The field name.
139109
140- def _explode_ssm_values (
141- self , field : ModelField , env_vars : Mapping [ str , Optional [ str ]]
142- ) -> Dict [ str , Any ]:
110+ Returns:
111+ A tuple contains the key, value if the file exists otherwise `None`, and
112+ a flag to determine whether value is complex.
143113 """
144- Process env_vars and extract the values of keys containing
145- env_nested_delimiter into nested dictionaries.
146114
147- This is applied to a single field, hence filtering by env_var prefix.
148- """
149- prefixes = [
150- f"{ env_name } { self .env_nested_delimiter } "
151- for env_name in field .field_info .extra ["env_names" ]
152- ]
153- result : Dict [str , Any ] = {}
154- for env_name , env_val in env_vars .items ():
155- if not any (env_name .startswith (prefix ) for prefix in prefixes ):
156- continue
157- _ , * keys , last_key = env_name .split (self .env_nested_delimiter )
158- env_var = result
159- for key in keys :
160- env_var = env_var .setdefault (key , {})
161- env_var [last_key ] = env_val
162-
163- return result
115+ # env_name = /asdf/foo
116+ # env_vars = {foo:xyz}
117+ env_val : str | None = None
118+ for field_key , env_name , value_is_complex in self ._extract_field_info (
119+ field , field_name
120+ ):
121+ env_val = self .env_vars .get (env_name )
122+ if env_val is not None :
123+ break
124+
125+ return env_val , field_key , value_is_complex
126+
127+ def __call__ (self ) -> dict [str , Any ]:
128+ data : dict [str , Any ] = {}
129+
130+ for field_name , field in self .settings_cls .model_fields .items ():
131+ try :
132+ field_value , field_key , value_is_complex = self .get_field_value (
133+ field , field_name
134+ )
135+ except Exception as e :
136+ raise SettingsError (
137+ f'error getting value for field "{ field_name } " from source "{ self .__class__ .__name__ } "' # noqa
138+ ) from e
139+
140+ try :
141+ field_value = self .prepare_field_value (
142+ field_name , field , field_value , value_is_complex
143+ )
144+ except ValueError as e :
145+ raise SettingsError (
146+ f'error parsing value for field "{ field_name } " from source "{ self .__class__ .__name__ } "' # noqa
147+ ) from e
148+
149+ if field_value is not None :
150+ if (
151+ not self .case_sensitive
152+ and lenient_issubclass (field .annotation , BaseModel )
153+ and isinstance (field_value , dict )
154+ ):
155+ data [field_key ] = self ._replace_field_names_case_insensitively (
156+ field , field_value
157+ )
158+ else :
159+ data [field_key ] = field_value
164160
165- def __repr__ (self ) -> str :
166- return f"AwsSsmSettingsSource(ssm_prefix={ self .ssm_prefix !r} )"
161+ return data
0 commit comments