-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoci_models.py
More file actions
141 lines (113 loc) · 3.91 KB
/
oci_models.py
File metadata and controls
141 lines (113 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
File name: oci_models.py
Author: Luigi Saetta
Date last modified: 2025-06-30
Python Version: 3.11
Description:
This module enables easy access to OCI GenAI LLM/Embeddings.
Usage:
Import this module into other scripts to use its functions.
Example:
from oci_models import get_llm
License:
This code is released under the MIT License.
Notes:
This is part of a demo showing how to implement an advanced
RAG solution as a LangGraph agent.
Modified to support xAI and OpenAI models through LangChain.
Warnings:
This module is in development, may change in future versions.
"""
# switched to the new OCI langchain integration
from langchain_oci import ChatOCIGenAI, OCIGenAIEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_oracledb.vectorstores import OracleVS
from custom_rest_embeddings import CustomRESTEmbeddings
from utils import get_console_logger
from config import (
USE_LANGCHAIN_OPENAI,
STREAMING,
AUTH,
SERVICE_ENDPOINT,
# used only for defaults
LLM_MODEL_ID,
TEMPERATURE,
MAX_TOKENS,
EMBED_MODEL_ID,
NVIDIA_EMBED_MODEL_URL,
)
from config_private import COMPARTMENT_ID
# console logger shared by the functions in this module
logger = get_console_logger()

# values accepted by get_embedding_model() for its model_type argument
ALLOWED_EMBED_MODELS_TYPE = {"OCI", "NVIDIA"}

# Model IDs for which get_llm() passes no model_kwargs at all
# (temperature / max_tokens are skipped): e.g. for gpt5, since max tokens
# is not supported, and for the OpenAI search models.
MODELS_WITHOUT_KWARGS = {
    "openai.gpt-oss-120b",
    "openai.gpt-5.2",
    "openai.gpt-4o-search-preview",
    "openai.gpt-4o-search-preview-2025-03-11",
}
def get_llm(model_id=LLM_MODEL_ID, temperature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """
    Build a ChatOCIGenAI chat model for the requested OCI GenAI model.

    Args:
        model_id: identifier of the model to use (defaults to LLM_MODEL_ID).
        temperature: temperature value forwarded to the model through
            model_kwargs (defaults to TEMPERATURE).
        max_tokens: max_tokens value forwarded to the model through
            model_kwargs (defaults to MAX_TOKENS).

    Returns:
        ChatOCIGenAI: a client configured with the module-level auth type,
        service endpoint, compartment and streaming settings.

    Raises:
        NotImplementedError: when USE_LANGCHAIN_OPENAI is enabled, since that
            integration path is not implemented yet.
    """
    # models listed in MODELS_WITHOUT_KWARGS do not accept these params,
    # so temperature and max_tokens are dropped for them
    tuning_params = (
        None
        if model_id in MODELS_WITHOUT_KWARGS
        else {"temperature": temperature, "max_tokens": max_tokens}
    )

    if USE_LANGCHAIN_OPENAI:
        # placeholder for the Langchain OpenAI path: log the request, then bail out
        logger.info("Using Langchain OpenAI integration for OCI models...")
        logger.info("Model ID: %s...", model_id)
        raise NotImplementedError("Feature not implemented yet...")

    return ChatOCIGenAI(
        auth_type=AUTH,
        model_id=model_id,
        service_endpoint=SERVICE_ENDPOINT,
        compartment_id=COMPARTMENT_ID,
        is_stream=STREAMING,
        model_kwargs=tuning_params,
    )
def get_embedding_model(model_type="OCI"):
    """
    Create the embeddings client that matches the requested model type.

    Args:
        model_type: one of the values in ALLOWED_EMBED_MODELS_TYPE
            ("OCI", the default, or "NVIDIA").

    Returns:
        An OCIGenAIEmbeddings instance when model_type is "OCI", or a
        CustomRESTEmbeddings instance (pointing at NVIDIA_EMBED_MODEL_URL)
        when model_type is "NVIDIA". Both use EMBED_MODEL_ID as the model.

    Raises:
        ValueError: if model_type is not one of the allowed values.
    """
    # reject unknown types before building anything
    if model_type not in ALLOWED_EMBED_MODELS_TYPE:
        raise ValueError(
            f"Invalid value for model_type: must be one of {ALLOWED_EMBED_MODELS_TYPE}"
        )

    embedder = None

    if model_type == "OCI":
        oci_settings = {
            "auth_type": AUTH,
            "model_id": EMBED_MODEL_ID,
            "service_endpoint": SERVICE_ENDPOINT,
            "compartment_id": COMPARTMENT_ID,
        }
        embedder = OCIGenAIEmbeddings(**oci_settings)
    elif model_type == "NVIDIA":
        rest_settings = {
            "api_url": NVIDIA_EMBED_MODEL_URL,
            "model": EMBED_MODEL_ID,
        }
        embedder = CustomRESTEmbeddings(**rest_settings)

    # the same EMBED_MODEL_ID is used for both model types
    logger.info("Embedding model is: %s", EMBED_MODEL_ID)

    return embedder
def get_oracle_vs(conn, collection_name, embed_model):
    """
    Create an OracleVS vector store bound to an existing DB table.

    Args:
        conn: The database connection object, passed to OracleVS as client.
        collection_name (str): The name of the collection (DB table) to search in.
        embed_model: The embedding model used as the store's embedding function.

    Returns:
        OracleVS: a vector store instance configured with the COSINE
        distance strategy.
    """
    return OracleVS(
        client=conn,
        table_name=collection_name,
        # the distance strategy is fixed to cosine for this store
        distance_strategy=DistanceStrategy.COSINE,
        embedding_function=embed_model,
    )