-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
382 lines (303 loc) · 16.3 KB
/
main.py
File metadata and controls
382 lines (303 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
# Standard Library
import argparse
import csv
import os
import re
import time
import urllib
import urllib.parse  # needed explicitly: "import urllib" alone does not load the parse submodule
# Third-Party Library
from dotenv import load_dotenv
from openai import OpenAI
# Local Imports
from Classes import ChatGPT, Prompts, WebScraper
from large_prompts.master_prompt import master_prompt
from re_functions.use_case_extractor import extract_use_cases
from model_apis.text_generation import claude_api, gemini_api, mistral_api
from model_apis.web_search import claude_search
# Load environment variables
load_dotenv()
# Constants
TOTAL_PAGE_CRAWLS = 4
def call_model_with_retry(model_name, model_type, formatted_prompt, web_scraper_obj, max_retries=4, retry_delay=10):
    """
    Generic function to call any supported LLM provider with retry logic.

    Args:
        model_name: Provider-specific model identifier (e.g. "claude-sonnet-4-20250514").
        model_type: Provider key; one of "chatgpt", "claude", "deepseek", "gemini", "mistral".
        formatted_prompt: Fully formatted prompt string to send to the model.
        web_scraper_obj: WebScraper instance; accumulates token cost via set_token_cost().
        max_retries: Number of attempts before giving up.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        The model's response text with a "END OF USE CASE" delimiter appended
        (consumed downstream when splitting combined multi-model output).

    Raises:
        ValueError: Immediately if model_type is unsupported. This is raised
            BEFORE the retry loop on purpose: a bad model_type is a programming
            error, and retrying it would only burn max_retries sleep cycles and
            mask the real bug as a RuntimeError.
        RuntimeError: If all max_retries attempts fail.
    """
    if model_type not in ("chatgpt", "claude", "deepseek", "gemini", "mistral"):
        raise ValueError(f"Unknown model type: {model_type}")
    for attempt in range(max_retries):
        try:
            if model_type == "chatgpt":
                chat_obj = ChatGPT(model_name, formatted_prompt, [], OpenAI(api_key=os.getenv("OpenAI_KEY"), max_retries=5))
                response, input_tokens, output_tokens = chat_obj.chat_model()
            elif model_type == "claude":
                response, input_tokens, output_tokens = claude_api(model_name, formatted_prompt)
            elif model_type == "deepseek":
                # DeepSeek exposes an OpenAI-compatible API, so the ChatGPT wrapper is reused with a custom base_url.
                chat_obj = ChatGPT(model_name, formatted_prompt, [], OpenAI(api_key=os.getenv("DEEPSEEK_KEY"), max_retries=5, base_url="https://api.deepseek.com"))
                response, input_tokens, output_tokens = chat_obj.chat_model()
            elif model_type == "gemini":
                response, input_tokens, output_tokens = gemini_api(model_name, formatted_prompt)
            else:  # "mistral" — guaranteed by the validation above
                response, input_tokens, output_tokens = mistral_api(model_name, formatted_prompt)
            # Update token cost ledger for this model
            web_scraper_obj.set_token_cost(input_tokens, output_tokens, model_name)
            response += "\n\n########END OF USE CASE########\n\n"
            return response
        except Exception as e:
            print(f"Error in {model_type.upper()} API call. Attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
            else:
                # Record a zero-cost entry so the model still appears in the cost ledger
                web_scraper_obj.set_token_cost(input_tokens=0, output_tokens=0, model_name=model_name)
                raise RuntimeError(f"{model_type.upper()} Classification failed") from e
def use_case_separator(all_use_cases):
    """Split the text produced by 'claude_search' into individual use cases.\n
    Headings of the form '\\n###### <name>' mark the start of each use case.\n
    Returns: A dictionary with use case names as keys and their descriptions as values."""
    # re.split with a capturing group yields:
    #   [preamble, name_1, body_1, name_2, body_2, ...]
    # so names sit at odd indices and bodies at the following even indices.
    pieces = re.split(r'\n###### (.+)', all_use_cases)
    headings = pieces[1::2]
    descriptions = pieces[2::2]
    return {
        heading.strip(): description.strip()
        for heading, description in zip(headings, descriptions)
    }
def run_search_workflow(input_file, output_file):
    """
    Process URLs from a CSV file using claude_search and output results to a new CSV file.

    Args:
        input_file: CSV with columns "Company Name" and "URLs".
        output_file: CSV receiving one row per extracted use case with columns
            "Company Name", "Use Case Name", "Use Case Description".

    The output file is re-opened in append mode after every processed URL so
    partial results survive a crash mid-run. Rows without a URL are skipped;
    per-URL failures are recorded as an "Error" row rather than aborting.
    """
    web_search_model = "claude-sonnet-4-20250514"
    try:
        # Initialize WebScraper for token cost tracking
        web_scraper_obj = WebScraper()
        # Initialize output CSV file with its header row
        with open(output_file, mode="w", newline="", encoding="utf-8-sig") as f:
            writer = csv.writer(f)
            writer.writerow(["Company Name", "Use Case Name", "Use Case Description"])
        # Read input CSV file
        with open(input_file, mode="r", encoding="utf-8-sig") as f:
            reader = csv.DictReader(f)
            for row_idx, row in enumerate(reader):
                company_name = row.get("Company Name", "")
                url = row.get("URLs", "")
                # Guard first: without a URL there is nothing to search.
                if not url:
                    continue
                # If no company name provided, derive one from the URL's domain
                if not company_name:
                    company_name = _company_name_from_url(url)
                # Reduce the URL to its registrable domain (e.g. "example.com")
                domain = _base_domain(url)
                if domain:
                    print(f"Processing URL {row_idx + 1}: {domain} (Company: {company_name})")
                    try:
                        # Perform search
                        search_result = claude_search(web_search_model, domain)
                        # Use the use_case_separator function to break down the use cases
                        use_cases_dict = use_case_separator(search_result)
                        print(f"Found {len(use_cases_dict)} use cases for {company_name}")
                        # Append results to CSV file
                        with open(output_file, mode="a", newline="", encoding="utf-8-sig") as f:
                            writer = csv.writer(f)
                            if use_cases_dict:
                                for use_case_name, use_case_description in use_cases_dict.items():
                                    writer.writerow([company_name, use_case_name, use_case_description])
                            else:
                                # If no use cases found, create a single row with an error message
                                writer.writerow([company_name, "No structured use cases found", "No use cases found in the search result."])
                        print(f"Successfully processed: {url}")
                    except Exception as e:
                        print(f"Error processing {url}: {str(e)}")
                        # Append error row to CSV
                        with open(output_file, mode="a", newline="", encoding="utf-8-sig") as f:
                            writer = csv.writer(f)
                            writer.writerow([company_name, "Error", f"Error: {str(e)}"])
        print(f"Search results saved to CSV: {output_file}")
    except Exception as e:
        print(f"Error in search workflow: {str(e)}")
        raise

def _company_name_from_url(url):
    """Derive a capitalized company name from the URL's second-level domain, or '' if the host has fewer than two labels."""
    domain_parts = urllib.parse.urlparse(url).netloc.split('.')
    return domain_parts[-2].capitalize() if len(domain_parts) >= 2 else ""

def _base_domain(url):
    """Return the last two host labels of *url* ("sub.example.com" -> "example.com"); hosts with fewer labels are returned unchanged."""
    netloc = urllib.parse.urlparse(url).netloc
    domain_parts = netloc.split('.')
    return '.'.join(domain_parts[-2:]) if len(domain_parts) >= 2 else netloc
def classify_from_csv(input_csv, output_csv, models):
    """
    Read use cases from a CSV file and classify them one by one, saving results to CSV.

    Args:
        input_csv: CSV with columns "Company Name", "Use Case Name",
            "Use Case Description".
        output_csv: CSV receiving one classified row per use case; written
            incrementally (re-opened in append mode per row) so partial
            results survive a crash.
        models: list of model keys; any subset of "chatgpt", "claude",
            "deepseek", "gemini", "mistral". Each selected model votes on a
            risk classification; the majority wins.

    Raises:
        ValueError: if none of the requested model names is recognized.
    """
    # Model configurations with hardcoded model names:
    # key -> (provider model id, model_type for call_model_with_retry, display name)
    all_model_configs = {
        "chatgpt": ("chatgpt-4o-latest", "chatgpt", "ChatGPT 4o"),
        "claude": ("claude-sonnet-4-20250514", "claude", "Claude Sonnet 4"),
        "deepseek": ("deepseek-reasoner", "deepseek", "DeepSeek Reasoner"),
        "gemini": ("gemini-2.0-flash-thinking-exp-01-21", "gemini", "Gemini 2.0 Flash Thinker"),
        "mistral": ("mistral-large-latest", "mistral", "Mistral Large")
    }
    # Filter model configurations based on user selection
    # (unrecognized names are silently dropped; only an empty result errors)
    model_configs = [all_model_configs[model] for model in models if model in all_model_configs]
    if not model_configs:
        raise ValueError("No valid models specified")
    print(f"Using {len(model_configs)} models: {[config[2] for config in model_configs]}")
    # Allowed categories. Any label a model emits outside this list is coerced
    # to "Uncertain". The list order is also used for tie-breaking below.
    allowed_categories = [
        'Prohibited AI system',
        'High-risk AI system under Annex I',
        'High-risk AI system under Annex III',
        'High-risk AI system with transparency obligations',
        'System with transparency obligations',
        'Low-risk AI system',
        'Uncertain'
    ]
    # Initialize the objects
    web_scraper_obj = WebScraper()
    prompts_obj = Prompts(TOTAL_PAGE_CRAWLS)
    MAX_API_TRIES = 4
    retry_delay = 10
    # Initialize output CSV file with its header row
    with open(output_csv, mode="w", newline="", encoding="utf-8-sig") as f:
        writer = csv.writer(f)
        writer.writerow(["Company Name", "Use Case Name", "Use Case Description", "Risk Classification", "Reason", "Model Distribution", "Chosen Model", "Token Cost ($)"])
    # Read input CSV file
    with open(input_csv, mode="r", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row_idx, row in enumerate(reader):
            company_name = row.get("Company Name", "")
            use_case_name = row.get("Use Case Name", "")
            use_case_description = row.get("Use Case Description", "")
            print(f"Processing row {row_idx + 1}: {company_name} - {use_case_name}")
            # Create use case object for processing
            use_case = {
                "use_case_name": use_case_name,
                "use_case_description": use_case_description
            }
            use_case_string = f"AI Use Case: {use_case_name}\nUse Case Description: {use_case_description}"
            # Prepare the prompt with the use case
            formatted_prompt = prompts_obj.prepare_AI_Act_prompt(master_prompt, use_case_string)
            # Call all models and collect responses; a model that exhausts its
            # retries contributes an "Error: ..." string instead of a response.
            model_responses = []
            for model_name, model_type, display_name in model_configs:
                try:
                    response = call_model_with_retry(model_name, model_type, formatted_prompt, web_scraper_obj, MAX_API_TRIES, retry_delay)
                    model_responses.append(response)
                except Exception as e:
                    print(f"Error with {display_name}: {str(e)}")
                    model_responses.append(f"Error: {str(e)}")
            # Combine all responses into one string for the extractor
            final_string = "".join(model_responses)
            # Get the combined json from all models
            result_json = extract_use_cases(use_case, final_string)
            # Create a dictionary to store the votes and store individual classifications from all models
            voters = [config[2] for config in model_configs]
            votings = {}
            classifications_list = []
            # Iterate through the result JSON and count votes for each classification
            for model_use_case in result_json:
                classification = model_use_case["Risk Classification"]
                if classification in allowed_categories:
                    if classification not in votings:
                        votings[classification] = 0
                    votings[classification] += 1
                    classifications_list.append(classification)
                else:
                    # If the classification is not in the allowed categories, re-classify it as "Uncertain"
                    model_use_case["Risk Classification"] = "Uncertain"
                    classifications_list.append("Uncertain")
                    if "Uncertain" not in votings:
                        votings["Uncertain"] = 0
                    votings["Uncertain"] += 1
            # Find the classification with the most votes
            # NOTE(review): if result_json is empty (extractor found nothing),
            # votings is empty and max() raises ValueError — confirm upstream
            # guarantees at least one parsed use case per row.
            max_votes = max(votings.values())
            classifications_with_max_votes = [classification for classification, votes in votings.items() if votes == max_votes]
            # If tie, pick least risky from the tie group.
            # NOTE(review): reverse=True selects the category LATEST in
            # allowed_categories; since 'Uncertain' follows 'Low-risk AI
            # system', a Low-risk/Uncertain tie resolves to 'Uncertain' —
            # confirm this matches the intended "least risky" rule.
            if len(classifications_with_max_votes) > 1:
                classifications_with_max_votes.sort(
                    key=lambda x: allowed_categories.index(x) if x in allowed_categories else -1,
                    reverse=True
                )
            final_classification = classifications_with_max_votes[0]
            # print(f"Final Classification: {final_classification}")
            # Store the vote distribution from each model.
            # NOTE(review): zip assumes result_json has exactly one entry per
            # model, in the same order as model_configs; if a model errored or
            # produced unparsable output, this mapping silently drifts — verify
            # against extract_use_cases.
            model_distribution = dict(zip(voters, classifications_list))
            model_distribution_string = ""
            for model, classification in model_distribution.items():
                model_distribution_string += f"{model}: {classification}\n"
            # Get the use cases with the final classification
            filtered_use_cases = [model_use_case for model_use_case in result_json if model_use_case["Risk Classification"] == final_classification]
            # Get the use case with the longest reason (used as the published rationale)
            longest_reasoned_use_case = max(filtered_use_cases, key=lambda x: len(x["Reason"]))
            # NOTE(review): same ordering assumption as above — the index into
            # result_json is reused as an index into voters.
            longest_reasoned_use_case_index = result_json.index(longest_reasoned_use_case)
            chosen_model = voters[longest_reasoned_use_case_index]
            # Get token cost for this classification (accumulated across all model calls for this row)
            token_cost = web_scraper_obj.get_token_cost()
            # Write result to CSV immediately
            with open(output_csv, mode="a", newline="", encoding="utf-8-sig") as f:
                writer = csv.writer(f)
                writer.writerow([
                    company_name,
                    use_case_name,
                    use_case_description,
                    final_classification,
                    longest_reasoned_use_case["Reason"],
                    model_distribution_string.strip(),
                    chosen_model,
                    token_cost
                ])
            # Reset token cost for next iteration
            web_scraper_obj.reset_token_cost()
    print(f"Classification complete. Results saved to: {output_csv}")
def main():
    """Command-line entry point for the EU AI Act risk classification tool.

    Parses arguments and dispatches to run_search_workflow() for the
    'search' command or classify_from_csv() for the 'classify' command.
    """
    parser = argparse.ArgumentParser(
        description="EU AI Act Risk Classification Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python main.py search -i companies.csv -o use_cases.csv
  python main.py classify -i use_cases.csv -o classifications.csv
  python main.py classify -i use_cases.csv -o classifications.csv -m chatgpt claude

Input CSV Formats:
  Search: Company Name, URLs (or URL)
  Classify: Company Name, Use Case Name, Use Case Description

Available Models: chatgpt, claude, deepseek, gemini, mistral
""")
    parser.add_argument("command", choices=["search", "classify"],
                        help="Command to run: 'search' for URL searching or 'classify' for risk classification")
    # Arguments shared by both commands
    parser.add_argument("--input-file", "-i", type=str, required=True,
                        help="Input CSV file path")
    parser.add_argument("--output-file", "-o", type=str, required=True,
                        help="Output CSV file path")
    # Classification-only arguments (ignored by 'search')
    parser.add_argument("--models", "-m", type=str, nargs='+',
                        choices=["chatgpt", "claude", "deepseek", "gemini", "mistral"],
                        default=["claude"],
                        metavar="MODEL",
                        help="Select one or more models for classification (default: claude)")
    args = parser.parse_args()
    if args.command == "search":
        # Plain string literals here: the originals were f-strings with no
        # placeholders (ruff F541); runtime output is unchanged.
        print("Running search workflow...")
        print(f"Input CSV: {args.input_file}")
        print(f"Output CSV: {args.output_file}")
        run_search_workflow(args.input_file, args.output_file)
    elif args.command == "classify":
        print("Running classification workflow...")
        print(f"Input CSV: {args.input_file}")
        print(f"Output CSV: {args.output_file}")
        print(f"Using models: {', '.join(args.models)}")
        classify_from_csv(
            input_csv=args.input_file,
            output_csv=args.output_file,
            models=args.models
        )
# Script entry point. (Commented-out API scratch/debug calls that previously
# followed this guard have been removed as dead code.)
if __name__ == "__main__":
    main()