Skip to content

Commit b7229a3

Browse files
Nova models addition (#1006)
* amazon nova models added, titan embeddings added * example env added with nova model config
1 parent f2a800d commit b7229a3

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

backend/example.env

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ YOUTUBE_TRANSCRIPT_PROXY="https://user:pass@domain:port"
4545
EFFECTIVE_SEARCH_RATIO=5
4646
GRAPH_CLEANUP_MODEL="openai_gpt_4o"
4747
CHUNKS_TO_BE_PROCESSED="50"
48+
BEDROCK_EMBEDDING_MODEL="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.titan-embed-text-v1"
49+
LLM_MODEL_CONFIG_bedrock_nova_micro_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-micro-v1:0"
50+
LLM_MODEL_CONFIG_bedrock_nova_lite_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-lite-v1:0"
51+
LLM_MODEL_CONFIG_bedrock_nova_pro_v1="model_name,aws_access_key,aws_secret_key,region_name" #model_name="amazon.nova-pro-v1:0"

backend/src/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_llm(model: str):
8989
)
9090

9191
llm = ChatBedrock(
92-
client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
92+
client=bedrock_client,region_name=region_name, model_id=model_name, model_kwargs=dict(temperature=0)
9393
)
9494

9595
elif "ollama" in model:

backend/src/shared/common_fn.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import os
1212
from pathlib import Path
1313
from urllib.parse import urlparse
14-
14+
import boto3
15+
from langchain_community.embeddings import BedrockEmbeddings
1516

1617
def check_url_source(source_type, yt_url:str=None, wiki_query:str=None):
1718
language=''
@@ -77,6 +78,10 @@ def load_embedding_model(embedding_model_name: str):
7778
)
7879
dimension = 768
7980
logging.info(f"Embedding: Using Vertex AI Embeddings , Dimension:{dimension}")
81+
elif embedding_model_name == "titan":
82+
embeddings = get_bedrock_embeddings()
83+
dimension = 1536
84+
logging.info(f"Embedding: Using bedrock titan Embeddings , Dimension:{dimension}")
8085
else:
8186
embeddings = HuggingFaceEmbeddings(
8287
model_name="all-MiniLM-L6-v2"#, cache_folder="/embedding_model"
@@ -134,4 +139,38 @@ def last_url_segment(url):
134139
parsed_url = urlparse(url)
135140
path = parsed_url.path.strip("/") # Remove leading and trailing slashes
136141
last_url_segment = path.split("/")[-1] if path else parsed_url.netloc.split(".")[0]
137-
return last_url_segment
142+
return last_url_segment
143+
144+
def get_bedrock_embeddings():
145+
"""
146+
Creates and returns a BedrockEmbeddings object using the specified model name.
147+
Args:
148+
model (str): The name of the model to use for embeddings.
149+
Returns:
150+
BedrockEmbeddings: An instance of the BedrockEmbeddings class.
151+
"""
152+
try:
153+
env_value = os.getenv("BEDROCK_EMBEDDING_MODEL")
154+
if not env_value:
155+
raise ValueError("Environment variable 'BEDROCK_EMBEDDING_MODEL' is not set.")
156+
try:
157+
model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
158+
except ValueError:
159+
raise ValueError(
160+
"Environment variable 'BEDROCK_EMBEDDING_MODEL' is improperly formatted. "
161+
"Expected format: 'model_name,aws_access_key,aws_secret_key,region_name'."
162+
)
163+
bedrock_client = boto3.client(
164+
service_name="bedrock-runtime",
165+
region_name=region_name.strip(),
166+
aws_access_key_id=aws_access_key.strip(),
167+
aws_secret_access_key=aws_secret_key.strip(),
168+
)
169+
bedrock_embeddings = BedrockEmbeddings(
170+
model_id=model_name.strip(),
171+
client=bedrock_client
172+
)
173+
return bedrock_embeddings
174+
except Exception as e:
175+
print(f"An unexpected error occurred: {e}")
176+
raise

0 commit comments

Comments
 (0)