Skip to content

Commit e1d807b

Browse files
feat(ia_utils): accept kwargs at assert text similarity (#441)
* accept kwargs at assert text similarity * copilot comments * test name * lint fix * fix test * fixed model_name * kwargs only to aditional params * fix comments * kwargs at load model
1 parent bc5ae5c commit e1d807b

File tree

4 files changed

+95
-18
lines changed

4 files changed

+95
-18
lines changed

toolium/test/utils/ai_utils/test_text_similarity.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_assert_text_similarity_with_default_method(similarity_mock):
140140
input_text = 'Today it will be sunny'
141141
expected_text = 'Today is sunny'
142142
assert_text_similarity(input_text, expected_text, threshold=0.8)
143-
similarity_mock.assert_called_once_with(input_text, expected_text)
143+
similarity_mock.assert_called_once_with(input_text, expected_text, None)
144144

145145

146146
@pytest.mark.skip(reason='Sentence Transformers model is not available in the CI environment')
@@ -157,7 +157,7 @@ def test_assert_text_similarity_with_configured_method(similarity_mock):
157157
input_text = 'Today it will be sunny'
158158
expected_text = 'Today is sunny'
159159
assert_text_similarity(input_text, expected_text, threshold=0.8)
160-
similarity_mock.assert_called_once_with(input_text, expected_text)
160+
similarity_mock.assert_called_once_with(input_text, expected_text, None)
161161

162162

163163
@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
@@ -173,4 +173,70 @@ def test_assert_text_similarity_with_configured_and_explicit_method(similarity_m
173173
input_text = 'Today it will be sunny'
174174
expected_text = 'Today is sunny'
175175
assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='spacy')
176-
similarity_mock.assert_called_once_with(input_text, expected_text)
176+
similarity_mock.assert_called_once_with(input_text, expected_text, None)
177+
178+
179+
@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
180+
def test_assert_text_similarity_with_configured_and_explicit_model(similarity_mock):
181+
config = DriverWrappersPool.get_default_wrapper().config
182+
try:
183+
config.add_section('AI')
184+
except Exception:
185+
pass
186+
config.set('AI', 'text_similarity_method', 'spacy')
187+
similarity_mock.return_value = 0.9
188+
189+
input_text = 'Today it will be sunny'
190+
expected_text = 'Today is sunny'
191+
assert_text_similarity(input_text, expected_text, threshold=0.8, model_name='en_core_web_lg')
192+
similarity_mock.assert_called_once_with(input_text, expected_text, 'en_core_web_lg')
193+
194+
195+
@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_spacy')
196+
def test_assert_text_similarity_with_configured_and_explicit_method_and_model(similarity_mock):
197+
config = DriverWrappersPool.get_default_wrapper().config
198+
try:
199+
config.add_section('AI')
200+
except Exception:
201+
pass
202+
config.set('AI', 'text_similarity_method', 'sentence_transformers')
203+
similarity_mock.return_value = 0.9
204+
205+
input_text = 'Today it will be sunny'
206+
expected_text = 'Today is sunny'
207+
assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='spacy',
208+
model_name='en_core_web_lg')
209+
similarity_mock.assert_called_once_with(input_text, expected_text, 'en_core_web_lg')
210+
211+
212+
@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_openai')
213+
def test_assert_text_similarity_with_explicit_openai(similarity_mock):
214+
config = DriverWrappersPool.get_default_wrapper().config
215+
try:
216+
config.add_section('AI')
217+
except Exception:
218+
pass
219+
config.set('AI', 'spacy_model', 'en_core_web_md')
220+
similarity_mock.return_value = 0.9
221+
222+
input_text = 'Today it will be sunny'
223+
expected_text = 'Today is sunny'
224+
assert_text_similarity(input_text, expected_text, threshold=0.8, similarity_method='openai',
225+
azure=True, model_name='gpt-4o-mini')
226+
similarity_mock.assert_called_once_with(input_text, expected_text, 'gpt-4o-mini', azure=True)
227+
228+
229+
@mock.patch('toolium.utils.ai_utils.text_similarity.get_text_similarity_with_openai')
230+
def test_azure_openai_request_params(similarity_mock):
231+
config = DriverWrappersPool.get_default_wrapper().config
232+
try:
233+
config.add_section('AI')
234+
except Exception:
235+
pass
236+
config.set('AI', 'text_similarity_method', 'azure_openai')
237+
similarity_mock.return_value = 0.9
238+
239+
input_text = 'Today it will be sunny'
240+
expected_text = 'Today is sunny'
241+
assert_text_similarity(input_text, expected_text, threshold=0.8)
242+
similarity_mock.assert_called_once_with(input_text, expected_text, None, azure=True)

toolium/utils/ai_utils/openai.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,23 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35-
def openai_request(system_message, user_message, model_name=None, azure=False):
35+
def openai_request(system_message, user_message, model_name=None, azure=False, **kwargs):
3636
"""
3737
Make a request to OpenAI API (Azure or standard)
3838
3939
:param system_message: system message to set the behavior of the assistant
4040
:param user_message: user message with the request
41-
:param model: model to use
41+
:param model_name: name of the model to use
4242
:param azure: whether to use Azure OpenAI or standard OpenAI
43+
:param kwargs: additional parameters to be passed to the OpenAI client (azure_endpoint, timeout, etc.)
4344
:returns: response from OpenAI
4445
"""
4546
if OpenAI is None:
4647
raise ImportError("OpenAI is not installed. Please run 'pip install toolium[ai]' to use OpenAI features")
4748
config = DriverWrappersPool.get_default_wrapper().config
4849
model_name = model_name or config.get_optional('AI', 'openai_model', 'gpt-4o-mini')
4950
logger.info(f"Calling to OpenAI API with model {model_name}")
50-
client = AzureOpenAI() if azure else OpenAI()
51+
client = AzureOpenAI(**kwargs) if azure else OpenAI(**kwargs)
5152
completion = client.chat.completions.create(
5253
model=model_name,
5354
messages=[

toolium/utils/ai_utils/spacy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,18 @@
3131

3232

3333
@lru_cache(maxsize=8)
34-
def get_spacy_model(model_name):
34+
def get_spacy_model(model_name, **kwargs):
3535
"""
3636
get spaCy model.
3737
This method uses lru cache to get spaCy model to improve performance.
3838
3939
:param model_name: spaCy model name
40+
:param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
4041
:return: spaCy model
4142
"""
4243
if spacy is None:
4344
return None
44-
return spacy.load(model_name)
45+
return spacy.load(model_name, **kwargs)
4546

4647

4748
def is_negator(tok):

toolium/utils/ai_utils/text_similarity.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@
3333
logger = logging.getLogger(__name__)
3434

3535

36-
def get_text_similarity_with_spacy(text, expected_text, model_name=None):
36+
def get_text_similarity_with_spacy(text, expected_text, model_name=None, **kwargs):
3737
"""
3838
Return similarity between two texts using spaCy.
3939
This method normalize both texts before comparing them.
4040
4141
:param text: string to compare
4242
:param expected_text: string with the expected text
4343
:param model_name: name of the spaCy model to use
44+
:param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
4445
:returns: similarity score between the two texts
4546
"""
4647
# NOTE: spaCy similarity performance can be enhanced using some strategies like:
@@ -49,7 +50,7 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
4950
# - Preprocessing texts. Now we only preprocess negations.
5051
config = DriverWrappersPool.get_default_wrapper().config
5152
model_name = model_name or config.get_optional('AI', 'spacy_model', 'en_core_web_md')
52-
model = get_spacy_model(model_name)
53+
model = get_spacy_model(model_name, **kwargs)
5354
if model is None:
5455
raise ImportError("spaCy is not installed. Please run 'pip install toolium[ai]' to use spaCy features")
5556
text = model(preprocess_with_ud_negation(text, model))
@@ -59,35 +60,38 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
5960
return similarity
6061

6162

62-
def get_text_similarity_with_sentence_transformers(text, expected_text, model_name=None):
63+
def get_text_similarity_with_sentence_transformers(text, expected_text, model_name=None, **kwargs):
6364
"""
6465
Return similarity between two texts using Sentence Transformers
6566
6667
:param text: string to compare
6768
:param expected_text: string with the expected text
6869
:param model_name: name of the Sentence Transformers model to use
70+
:param kwargs: additional parameters to be used by SentenceTransformer (modules, device, prompts, etc.)
6971
:returns: similarity score between the two texts
7072
"""
7173
if SentenceTransformer is None:
7274
raise ImportError("Sentence Transformers is not installed. Please run 'pip install toolium[ai]'"
7375
" to use Sentence Transformers features")
7476
config = DriverWrappersPool.get_default_wrapper().config
7577
model_name = model_name or config.get_optional('AI', 'sentence_transformers_model', 'all-mpnet-base-v2')
76-
model = SentenceTransformer(model_name)
78+
model = SentenceTransformer(model_name, **kwargs)
7779
similarity = float(model.similarity(model.encode(expected_text), model.encode(text)))
7880
# similarity can be slightly > 1 due to float precision
7981
similarity = 1 if similarity > 1 else similarity
8082
logger.info(f"Sentence Transformers similarity: {similarity} between '{text}' and '{expected_text}'")
8183
return similarity
8284

8385

84-
def get_text_similarity_with_openai(text, expected_text, azure=False):
86+
def get_text_similarity_with_openai(text, expected_text, model_name=None, azure=False, **kwargs):
8587
"""
8688
Return semantic similarity between two texts using OpenAI LLM
8789
8890
:param text: string to compare
8991
:param expected_text: string with the expected text
92+
:param model_name: name of the OpenAI model to use
9093
:param azure: whether to use Azure OpenAI or standard OpenAI
94+
:param kwargs: additional parameters to be used by OpenAI client
9195
:returns: tuple with similarity score between the two texts and explanation
9296
"""
9397
system_message = (
@@ -102,7 +106,7 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
102106
f"The expected answer is: {expected_text}."
103107
f" The LLM answer is: {text}."
104108
)
105-
response = openai_request(system_message, user_message, azure=azure)
109+
response = openai_request(system_message, user_message, model_name, azure, **kwargs)
106110
try:
107111
response = json.loads(response)
108112
similarity = float(response['similarity'])
@@ -114,18 +118,20 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
114118
return similarity
115119

116120

117-
def get_text_similarity_with_azure_openai(text, expected_text):
121+
def get_text_similarity_with_azure_openai(text, expected_text, model_name=None, **kwargs):
118122
"""
119123
Return semantic similarity between two texts using Azure OpenAI LLM
120124
121125
:param text: string to compare
122126
:param expected_text: string with the expected text
127+
:param model_name: name of the Azure OpenAI model to use
128+
:param kwargs: additional parameters to be used by Azure OpenAI client
123129
:returns: tuple with similarity score between the two texts and explanation
124130
"""
125-
return get_text_similarity_with_openai(text, expected_text, azure=True)
131+
return get_text_similarity_with_openai(text, expected_text, model_name, azure=True, **kwargs)
126132

127133

128-
def assert_text_similarity(text, expected_texts, threshold, similarity_method=None):
134+
def assert_text_similarity(text, expected_texts, threshold, similarity_method=None, model_name=None, **kwargs):
129135
"""
130136
Get similarity between one text and a list of expected texts and assert if any of the expected texts is similar.
131137
@@ -134,14 +140,17 @@ def assert_text_similarity(text, expected_texts, threshold, similarity_method=No
134140
:param threshold: minimum similarity score to consider texts similar
135141
:param similarity_method: method to use for text comparison ('spacy', 'sentence_transformers', 'openai'
136142
or 'azure_openai')
143+
:param model_name: model name to use for the similarity method
144+
:param kwargs: additional parameters to be used by comparison methods
137145
"""
138146
config = DriverWrappersPool.get_default_wrapper().config
139147
similarity_method = similarity_method or config.get_optional('AI', 'text_similarity_method', 'spacy')
140148
expected_texts = [expected_texts] if isinstance(expected_texts, str) else expected_texts
141149
error_message = ""
142150
for expected_text in expected_texts:
143151
try:
144-
similarity = globals()[f'get_text_similarity_with_{similarity_method}'](text, expected_text)
152+
similarity = globals()[f'get_text_similarity_with_{similarity_method}'](text, expected_text,
153+
model_name, **kwargs)
145154
except KeyError:
146155
raise ValueError(f"Unknown similarity_method: '{similarity_method}', please use 'spacy',"
147156
f" 'sentence_transformers', 'openai' or 'azure_openai'")

0 commit comments

Comments
 (0)