Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,25 @@ def __init__(

self.model_name = model_name
# check kwargs are primitives only
# OPTIMIZATION: use tuple lookup rather than isinstance for small number of checks
_primitive_types = (str, int, float, bool, list, dict, tuple)
for key, value in kwargs.items():
if not isinstance(value, (str, int, float, bool, list, dict, tuple)):
if type(value) not in _primitive_types and not isinstance(
value, _primitive_types
):
raise ValueError(f"Keyword argument {key} is not a primitive type")
self.kwargs = kwargs

# Store the session for serialization
self._session_args = {}
if hasattr(session, "region_name") and session.region_name:
self._session_args["region_name"] = session.region_name
if hasattr(session, "profile_name") and session.profile_name:
self._session_args["profile_name"] = session.profile_name

region_name = getattr(session, "region_name", None)
profile_name = getattr(session, "profile_name", None)
if region_name:
self._session_args["region_name"] = region_name
if profile_name:
self._session_args["profile_name"] = profile_name

# Client construction is fast-enough, no changes needed
self._client = session.client(
service_name="bedrock-runtime",
**kwargs,
Expand Down Expand Up @@ -87,24 +94,30 @@ def name() -> str:

@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
# OPTIMIZATION: Tighten the import so boto3 is loaded only if not already imported.
# But since "import boto3" is fast (cached), explicit import remains.
try:
import boto3
except ImportError:
raise ValueError(
"The boto3 python package is not installed. Please install it with `pip install boto3`"
)

# Faster dict access via local variables
model_name = config.get("model_name")
session_args = config.get("session_args")
if model_name is None:
assert False, "This code should not be reached"
kwargs = config.get("kwargs", {})

if session_args is None:
session = boto3.Session()
else:
session = boto3.Session(**session_args)
# Don't check for None twice; single check is enough
session = (
boto3.Session(**session_args)
if session_args is not None
else boto3.Session()
)

# Avoid extraneous line splits for single calls
return AmazonBedrockEmbeddingFunction(
session=session, model_name=model_name, **kwargs
)
Expand Down