3333logger = logging .getLogger (__name__ )
3434
3535
36- def get_text_similarity_with_spacy (text , expected_text , model_name = None ):
36+ def get_text_similarity_with_spacy (text , expected_text , model_name = None , ** kwargs ):
3737 """
3838 Return similarity between two texts using spaCy.
3939 This method normalizes both texts before comparing them.
4040
4141 :param text: string to compare
4242 :param expected_text: string with the expected text
4343 :param model_name: name of the spaCy model to use
44+ :param kwargs: additional parameters to be used by spaCy (disable, exclude, etc.)
4445 :returns: similarity score between the two texts
4546 """
4647 # NOTE: spaCy similarity performance can be enhanced using some strategies like:
@@ -49,7 +50,7 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
4950 # - Preprocessing texts. Now we only preprocess negations.
5051 config = DriverWrappersPool .get_default_wrapper ().config
5152 model_name = model_name or config .get_optional ('AI' , 'spacy_model' , 'en_core_web_md' )
52- model = get_spacy_model (model_name )
53+ model = get_spacy_model (model_name , ** kwargs )
5354 if model is None :
5455 raise ImportError ("spaCy is not installed. Please run 'pip install toolium[ai]' to use spaCy features" )
5556 text = model (preprocess_with_ud_negation (text , model ))
@@ -59,35 +60,38 @@ def get_text_similarity_with_spacy(text, expected_text, model_name=None):
5960 return similarity
6061
6162
def get_text_similarity_with_sentence_transformers(text, expected_text, model_name=None, **kwargs):
    """
    Return similarity between two texts using Sentence Transformers

    :param text: string to compare
    :param expected_text: string with the expected text
    :param model_name: name of the Sentence Transformers model to use. When empty, the
                       'sentence_transformers_model' option of the 'AI' config section is used,
                       falling back to 'all-mpnet-base-v2'
    :param kwargs: additional parameters to be used by SentenceTransformer (modules, device, prompts, etc.)
    :returns: similarity score between the two texts, as a float capped at 1.0
    :raises ImportError: if Sentence Transformers is not installed
    """
    if SentenceTransformer is None:
        raise ImportError("Sentence Transformers is not installed. Please run 'pip install toolium[ai]'"
                          " to use Sentence Transformers features")
    config = DriverWrappersPool.get_default_wrapper().config
    model_name = model_name or config.get_optional('AI', 'sentence_transformers_model', 'all-mpnet-base-v2')
    # Extra keyword arguments are forwarded untouched to the SentenceTransformer constructor
    model = SentenceTransformer(model_name, **kwargs)
    similarity = float(model.similarity(model.encode(expected_text), model.encode(text)))
    # similarity can be slightly > 1 due to float precision; cap it with a float literal
    # so the returned value is always a float (never the int 1)
    similarity = min(similarity, 1.0)
    logger.info(f"Sentence Transformers similarity: {similarity} between '{text}' and '{expected_text}'")
    return similarity
8284
8385
84- def get_text_similarity_with_openai (text , expected_text , azure = False ):
86+ def get_text_similarity_with_openai (text , expected_text , model_name = None , azure = False , ** kwargs ):
8587 """
8688 Return semantic similarity between two texts using OpenAI LLM
8789
8890 :param text: string to compare
8991 :param expected_text: string with the expected text
92+ :param model_name: name of the OpenAI model to use
9093 :param azure: whether to use Azure OpenAI or standard OpenAI
94+ :param kwargs: additional parameters to be used by OpenAI client
9195 :returns: similarity score between the two texts
9296 """
9397 system_message = (
@@ -102,7 +106,7 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
102106 f"The expected answer is: { expected_text } ."
103107 f" The LLM answer is: { text } ."
104108 )
105- response = openai_request (system_message , user_message , azure = azure )
109+ response = openai_request (system_message , user_message , model_name , azure , ** kwargs )
106110 try :
107111 response = json .loads (response )
108112 similarity = float (response ['similarity' ])
@@ -114,18 +118,20 @@ def get_text_similarity_with_openai(text, expected_text, azure=False):
114118 return similarity
115119
116120
def get_text_similarity_with_azure_openai(text, expected_text, model_name=None, **kwargs):
    """
    Return semantic similarity between two texts using Azure OpenAI LLM.
    Thin wrapper that delegates to get_text_similarity_with_openai with azure enabled.

    :param text: string to compare
    :param expected_text: string with the expected text
    :param model_name: name of the Azure OpenAI model to use
    :param kwargs: additional parameters forwarded to get_text_similarity_with_openai
    :returns: the value returned by get_text_similarity_with_openai
    """
    # Same comparison as the standard OpenAI variant, only the azure flag differs
    result = get_text_similarity_with_openai(text, expected_text, model_name, azure=True, **kwargs)
    return result
126132
127133
128- def assert_text_similarity (text , expected_texts , threshold , similarity_method = None ):
134+ def assert_text_similarity (text , expected_texts , threshold , similarity_method = None , model_name = None , ** kwargs ):
129135 """
130136 Get similarity between one text and a list of expected texts and assert if any of the expected texts is similar.
131137
@@ -134,14 +140,17 @@ def assert_text_similarity(text, expected_texts, threshold, similarity_method=No
134140 :param threshold: minimum similarity score to consider texts similar
135141 :param similarity_method: method to use for text comparison ('spacy', 'sentence_transformers', 'openai'
136142 or 'azure_openai')
143+ :param model_name: model name to use for the similarity method
144+ :param kwargs: additional parameters to be used by comparison methods
137145 """
138146 config = DriverWrappersPool .get_default_wrapper ().config
139147 similarity_method = similarity_method or config .get_optional ('AI' , 'text_similarity_method' , 'spacy' )
140148 expected_texts = [expected_texts ] if isinstance (expected_texts , str ) else expected_texts
141149 error_message = ""
142150 for expected_text in expected_texts :
143151 try :
144- similarity = globals ()[f'get_text_similarity_with_{ similarity_method } ' ](text , expected_text )
152+ similarity = globals ()[f'get_text_similarity_with_{ similarity_method } ' ](text , expected_text ,
153+ model_name , ** kwargs )
145154 except KeyError :
146155 raise ValueError (f"Unknown similarity_method: '{ similarity_method } ', please use 'spacy',"
147156 f" 'sentence_transformers', 'openai' or 'azure_openai'" )
0 commit comments