@@ -58,9 +58,22 @@ def conform_query(query: str, provider: str) -> dict:
5858
5959
6060class BedrockEmbeddingConfig (EmbeddingConfig ):
61- aws_access_key_id : SecretStr = Field (description = "aws access key id" )
62- aws_secret_access_key : SecretStr = Field (description = "aws secret access key" )
63- region_name : str = Field (description = "aws region name" , default = "us-west-2" )
61+ aws_access_key_id : SecretStr | None = Field (description = "aws access key id" , default = None )
62+ aws_secret_access_key : SecretStr | None = Field (
63+ description = "aws secret access key" , default = None
64+ )
65+ region_name : str = Field (
66+ description = "aws region name" ,
67+ default_factory = lambda : (
68+ os .getenv ("BEDROCK_REGION_NAME" ) or
69+ os .getenv ("AWS_DEFAULT_REGION" ) or
70+ "us-west-2"
71+ )
72+ )
73+ endpoint_url : str | None = Field (description = "custom bedrock endpoint url" , default = None )
74+ access_method : str = Field (
75+ description = "authentication method" , default = "credentials"
76+ ) # "credentials" or "iam"
6477 embedder_model_name : str = Field (
6578 default = "amazon.titan-embed-text-v1" ,
6679 alias = "model_name" ,
@@ -96,6 +109,20 @@ def wrap_error(self, e: Exception) -> Exception:
96109 return e
97110
98111 def run_precheck (self ) -> None :
112+ # Validate access method and credentials configuration
113+ if self .access_method == "credentials" :
114+ if not (self .aws_access_key_id and self .aws_secret_access_key ):
115+ raise ValueError (
116+ "Credentials access method requires aws_access_key_id and aws_secret_access_key"
117+ )
118+ elif self .access_method == "iam" :
119+ # For IAM, credentials are handled by AWS SDK
120+ pass
121+ else :
122+ raise ValueError (
123+ f"Invalid access_method: { self .access_method } . Must be 'credentials' or 'iam'"
124+ )
125+
99126 client = self .get_bedrock_client ()
100127 try :
101128 model_info = client .list_foundation_models (byOutputModality = "EMBEDDING" )
@@ -113,11 +140,30 @@ def run_precheck(self) -> None:
113140 raise self .wrap_error (e = e )
114141
115142 def get_client_kwargs (self ) -> dict :
116- return {
117- "aws_access_key_id" : self .aws_access_key_id .get_secret_value (),
118- "aws_secret_access_key" : self .aws_secret_access_key .get_secret_value (),
143+ kwargs = {
119144 "region_name" : self .region_name ,
120145 }
146+
147+ if self .endpoint_url :
148+ kwargs ["endpoint_url" ] = self .endpoint_url
149+
150+ if self .access_method == "credentials" :
151+ if self .aws_access_key_id and self .aws_secret_access_key :
152+ kwargs ["aws_access_key_id" ] = self .aws_access_key_id .get_secret_value ()
153+ kwargs ["aws_secret_access_key" ] = self .aws_secret_access_key .get_secret_value ()
154+ else :
155+ raise ValueError (
156+ "Credentials access method requires aws_access_key_id and aws_secret_access_key"
157+ )
158+ elif self .access_method == "iam" :
159+ # For IAM, boto3 will use default credential chain (IAM roles, environment, etc.)
160+ pass
161+ else :
162+ raise ValueError (
163+ f"Invalid access_method: { self .access_method } . Must be 'credentials' or 'iam'"
164+ )
165+
166+ return kwargs
121167
122168 @requires_dependencies (
123169 ["boto3" ],
0 commit comments