1616limitations under the License.
1717"""
1818import logging
19-
20- try :
21- from sentence_transformers import SentenceTransformer , util
22- except ImportError :
23- SentenceTransformer = None
19+ import json
2420
2521from toolium .utils .ai_utils .openai import openai_request
26- from toolium .driver_wrappers_pool import DriverWrappersPool
2722
2823# Configure logger
2924logger = logging .getLogger (__name__ )
@@ -50,11 +45,6 @@ def build_system_message(characteristics):
5045 - "name": the exact characteristic name as listed above.
5146 - "score": a float between 0.0 and 0.2.
5247 3) Compute an overall score "overall_match" between 0.0 and 1.0 that summarizes how well the text matches the whole set. It does not have to be a simple arithmetic mean, but must still be in [0.0, 1.0].
53- 4) Produce a "data" object that can contain extra structured analysis sections:
54- - "data" MUST always be present.
55- - "data" MUST be a JSON object.
56- - Each key in "data" is the title/name of a section (e.g. "genres", "entities", "style_breakdown").
57- - Each value is a JSON array (the structure of its objects will be defined by additional system instructions).
5848
5949 Output format (IMPORTANT):
6050 Return ONLY a single valid JSON object with this exact top-level structure and property names:
@@ -66,17 +56,11 @@ def build_system_message(characteristics):
6656 "name": string,
6757 "score": float
6858 }}
69- ],
70- "data": {{
71- "<section_title>": [
72- {{}}
73- ]
74- }}
59+ ]
7560 }}
7661
7762 Constraints:
7863 - Do NOT include scores for high valued (<=0.2) features at features list.
79- - The "data" field must ALWAYS be present. If there are no extra sections, it MUST be: "data": {{}}.
8064 - Use a dot as decimal separator (e.g. 0.75, not 0,75).
8165 - Use at most 2 decimal places for all scores.
8266 - Do NOT include any text outside the JSON (no Markdown, no comments, no explanations).
@@ -85,83 +69,43 @@ def build_system_message(characteristics):
8569 return base_prompt .strip ()
8670
8771
88- def get_text_criteria_analysis_openai (text_input , target_features , extra_tasks = None , model_name = None ,
89- azure = False , ** kwargs ):
72+ def get_text_criteria_analysis (text_input , text_criteria , model_name = None , azure = False , ** kwargs ):
9073 """
9174 Get text criteria analysis using Azure OpenAI. To analyze how well a given text
9275 matches a set of target characteristics.
9376 The response is a structured JSON object with overall match score, individual feature scores,
9477 and additional data sections.
9578
9679 :param text_input: text to analyze
97- :param target_features : list of target characteristics to evaluate
80+ :param text_criteria : list of target characteristics to evaluate
9881 :param extra_tasks: additional system messages for extra analysis sections (optional)
9982 :param model_name: name of the Azure OpenAI model to use
10083 :param azure: whether to use Azure OpenAI or standard OpenAI
10184 :param kwargs: additional parameters to be used by Azure OpenAI client
10285 :returns: response from Azure OpenAI
10386 """
10487 # Build prompt using base prompt and target features
105- system_message = build_system_message (target_features )
88+ system_message = build_system_message (text_criteria )
10689 msg = [system_message ]
107- if extra_tasks :
108- if isinstance (extra_tasks , list ):
109- for task in extra_tasks :
110- msg .append (task )
111- else :
112- msg .append (extra_tasks )
11390 return openai_request (msg , text_input , model_name , azure , ** kwargs )
11491
11592
116- def get_text_criteria_analysis_sentence_transformers (text_input , target_features , extra_tasks = None ,
117- model_name = None , azure = True , ** kwargs ):
93+ def assert_text_criteria (text_input , text_criteria , threshold , model_name = None , azure = False , ** kwargs ):
11894 """
119- Get text criteria analysis using Sentence Transformers. Sentence Transformers works better using examples
120- that are semantically similar, so this method is more suitable for evaluating characteristics like
121- "is a greeting phrase", "talks about the weather", etc.
122- The response is a structured JSON object with overall match score, individual feature scores,
123- and additional data sections.
95+ Get text criteria analysis and assert if overall match score is above threshold.
12496
12597 :param text_input: text to analyze
126- :param target_features: list of target characteristics to evaluate
127- :param extra_tasks: additional system messages for extra analysis sections (not used here, for compatibility)
128- :param model_name: name of the Sentence Transformers model to use
129- :param azure: whether to use Azure OpenAI or standard OpenAI (not used here, for compatibility)
130- :param kwargs: additional parameters to be used by Sentence Transformers client
98+ :param text_criteria: list of target characteristics to evaluate
99+ :param threshold: minimum overall match score to consider the text acceptable
100+ :param model_name: name of the Azure OpenAI model to use
101+ :param azure: whether to use Azure OpenAI or standard OpenAI
102+ :param kwargs: additional parameters to be used by Azure OpenAI client
103+ :raises AssertionError: if overall match score is below threshold
131104 """
132- if SentenceTransformer is None :
133- raise ImportError ("Sentence Transformers is not installed. Please run 'pip install toolium[ai]'"
134- " to use Sentence Transformers features" )
135-
136- def similarity_to_score (cos_sim ):
137- if cos_sim <= 0.1 :
138- return 0.0
139- return cos_sim / 0.7
140-
141- config = DriverWrappersPool .get_default_wrapper ().config
142- model_name = model_name or config .get_optional ('AI' , 'sentence_transformers_model' , 'all-mpnet-base-v2' )
143- model = SentenceTransformer (model_name , ** kwargs )
144- # Pre-compute feature embeddings
145- feature_embs = model .encode ([f for f in target_features ], normalize_embeddings = True )
146- # text_input embedding
147- text_emb = model .encode (text_input , normalize_embeddings = True )
148- # Computes cosine-similarities between the text and features tensors (range [-1, 1])
149- sims = util .cos_sim (text_emb , feature_embs )[0 ].tolist ()
150- results = []
151- # Generate contracted results
152- for f , sim in zip (target_features , sims ):
153- # Normalize similarity from [-1, 1] to [0, 1]
154- score = similarity_to_score (sim )
155- results .append ({
156- "name" : f ,
157- "score" : round (score , 2 )
158- })
159-
160- # overall score as average of feature scores
161- overall = sum (r ["score" ] for r in results ) / len (results )
162-
163- return {
164- "overall_match" : round (overall , 2 ),
165- "features" : results ,
166- "data" : {}
167- }
105+ analysis = json .loads (get_text_criteria_analysis (text_input , text_criteria , model_name , azure , ** kwargs ))
106+ overall_match = analysis .get ("overall_match" , 0.0 )
107+ if overall_match < threshold :
108+ raise AssertionError (f"Text criteria analysis failed: overall match { overall_match } "
109+ f"is below threshold { threshold } " )
110+ logger .info (f"Text criteria analysis passed: overall match { overall_match } "
111+ f"is above threshold { threshold } " )
0 commit comments