Skip to content

Commit 811105e

Browse files
kwargs only to aditional params
1 parent 9368f1f commit 811105e

File tree

3 files changed

+17
-18
lines changed

3 files changed

+17
-18
lines changed

docs/ai_utils.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ completely different).
3535
threshold = 0.8 # Similarity threshold between 0 and 1
3636
similarity_method = 'spacy' # Options: 'spacy', 'sentence_transformers', 'openai', 'azure_openai'
3737
model_name = 'en_core_web_md' # Model name to use for the selected method
38-
azure = True # Set to True if using Azure OpenAI (only for 'openai' method). False by default, to use OpenAI directly. For azure_openai method, True by default.
38+
**kwargs = {} # Additional parameters for the selected method.
3939
4040
# Validate similarity
4141
assert_text_similarity(input_text, expected_text, threshold=threshold, similarity_method=similarity_method)

toolium/utils/ai_utils/openai.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,23 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35-
def openai_request(system_message, user_message, model_name=None, **kwargs):
35+
def openai_request(system_message, user_message, model_name=None, azure=False, **kwargs):
3636
"""
3737
Make a request to OpenAI API (Azure or standard)
3838
3939
:param system_message: system message to set the behavior of the assistant
4040
:param user_message: user message with the request
4141
:param model_name: name of the model to use
42-
:param kwargs: additional parameters, including:
43-
- azure: whether to use Azure OpenAI or standard OpenAI
42+
:param azure: whether to use Azure OpenAI or standard OpenAI
43+
:param kwargs: additional parameters to be passed to the OpenAI client (azure_endpoint, timeout, etc.)
4444
:returns: response from OpenAI
4545
"""
4646
if OpenAI is None:
4747
raise ImportError("OpenAI is not installed. Please run 'pip install toolium[ai]' to use OpenAI features")
4848
config = DriverWrappersPool.get_default_wrapper().config
4949
model_name = model_name or config.get_optional('AI', 'openai_model', 'gpt-4o-mini')
5050
logger.info(f"Calling to OpenAI API with model {model_name}")
51-
client = AzureOpenAI() if kwargs.get('azure', False) else OpenAI()
51+
client = AzureOpenAI(**kwargs) if azure else OpenAI(**kwargs)
5252
completion = client.chat.completions.create(
5353
model=model_name,
5454
messages=[

toolium/utils/ai_utils/text_similarity.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None, **kwarg
4141
:param text: string to compare
4242
:param expected_text: string with the expected text
4343
:param model_name: name of the spaCy model to use
44-
:param kwargs: additional parameters
44+
:param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
4545
:returns: similarity score between the two texts
4646
"""
4747
# NOTE: spaCy similarity performance can be enhanced using some strategies like:
@@ -50,7 +50,7 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None, **kwarg
5050
# - Preprocessing texts. Now we only preprocess negations.
5151
config = DriverWrappersPool.get_default_wrapper().config
5252
model_name = model_name or config.get_optional('AI', 'spacy_model', 'en_core_web_md')
53-
model = get_spacy_model(model_name)
53+
model = get_spacy_model(model_name, **kwargs)
5454
if model is None:
5555
raise ImportError("spaCy is not installed. Please run 'pip install toolium[ai]' to use spaCy features")
5656
text = model(preprocess_with_ud_negation(text, model))
@@ -67,31 +67,31 @@ def get_text_similarity_with_sentence_transformers(text, expected_text, model_na
6767
:param text: string to compare
6868
:param expected_text: string with the expected text
6969
:param model_name: name of the Sentence Transformers model to use
70-
:param kwargs: additional parameters
70+
:param kwargs: additional parameters to be used by SentenceTransformer (modules, device, prompts, etc.)
7171
:returns: similarity score between the two texts
7272
"""
7373
if SentenceTransformer is None:
7474
raise ImportError("Sentence Transformers is not installed. Please run 'pip install toolium[ai]'"
7575
" to use Sentence Transformers features")
7676
config = DriverWrappersPool.get_default_wrapper().config
7777
model_name = model_name or config.get_optional('AI', 'sentence_transformers_model', 'all-mpnet-base-v2')
78-
model = SentenceTransformer(model_name)
78+
model = SentenceTransformer(model_name, **kwargs)
7979
similarity = float(model.similarity(model.encode(expected_text), model.encode(text)))
8080
# similarity can be slightly > 1 due to float precision
8181
similarity = 1 if similarity > 1 else similarity
8282
logger.info(f"Sentence Transformers similarity: {similarity} between '{text}' and '{expected_text}'")
8383
return similarity
8484

8585

86-
def get_text_similarity_with_openai(text, expected_text, model_name=None, **kwargs):
86+
def get_text_similarity_with_openai(text, expected_text, model_name=None, azure=False, **kwargs):
8787
"""
8888
Return semantic similarity between two texts using OpenAI LLM
8989
9090
:param text: string to compare
9191
:param expected_text: string with the expected text
9292
:param model_name: name of the OpenAI model to use
93-
:param kwargs: additional parameters including:
94-
- azure: whether to use Azure OpenAI or standard OpenAI
93+
:param azure: whether to use Azure OpenAI or standard OpenAI
94+
:param kwargs: additional parameters to be used by OpenAI client
9595
:returns: tuple with similarity score between the two texts and explanation
9696
"""
9797
system_message = (
@@ -106,7 +106,7 @@ def get_text_similarity_with_openai(text, expected_text, model_name=None, **kwar
106106
f"The expected answer is: {expected_text}."
107107
f" The LLM answer is: {text}."
108108
)
109-
response = openai_request(system_message, user_message, model_name, **kwargs)
109+
response = openai_request(system_message, user_message, model_name, azure, **kwargs)
110110
try:
111111
response = json.loads(response)
112112
similarity = float(response['similarity'])
@@ -124,12 +124,11 @@ def get_text_similarity_with_azure_openai(text, expected_text, model_name=None,
124124
125125
:param text: string to compare
126126
:param expected_text: string with the expected text
127-
:param model_name: name of the OpenAI model to use
128-
:param kwargs: additional parameters
127+
:param model_name: name of the Azure OpenAI model to use
128+
:param kwargs: additional parameters to be used by Azure OpenAI client
129129
:returns: tuple with similarity score between the two texts and explanation
130130
"""
131-
kwargs["azure"] = True
132-
return get_text_similarity_with_openai(text, expected_text, model_name, **kwargs)
131+
return get_text_similarity_with_openai(text, expected_text, model_name, azure=True, **kwargs)
133132

134133

135134
def assert_text_similarity(text, expected_texts, threshold, similarity_method=None, model_name=None, **kwargs):
@@ -142,7 +141,7 @@ def assert_text_similarity(text, expected_texts, threshold, similarity_method=No
142141
:param similarity_method: method to use for text comparison ('spacy', 'sentence_transformers', 'openai'
143142
or 'azure_openai')
144143
:param model_name: model name to use for the similarity method
145-
:param kwargs: additional parameters including azure flag for openai methods
144+
:param kwargs: additional parameters to be used by OpenAI methods
146145
"""
147146
config = DriverWrappersPool.get_default_wrapper().config
148147
similarity_method = similarity_method or config.get_optional('AI', 'text_similarity_method', 'spacy')

0 commit comments

Comments
 (0)