-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathai_reprocess_clean.py
More file actions
318 lines (271 loc) · 12.1 KB
/
ai_reprocess_clean.py
File metadata and controls
318 lines (271 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import asyncio
import json
import aiohttp
from concept_graph import build_graph
async def ai_reprocess_nodes(note_text, current_nodes, analysis_type='bridges', ai_provider=None,
                             api_key=None, ai_model=None, host=None, port=None):
    """Use AI to filter and select only the most important nodes from current graph."""
    # Without a provider or any nodes there is nothing to filter.
    if not ai_provider or not current_nodes:
        return current_nodes
    # Hosted providers authenticate with an API key; local ones need host/port.
    if ai_provider in ['openai', 'openrouter', 'google', 'groq'] and not api_key:
        return current_nodes
    if ai_provider in ['lmstudio', 'ollama'] and (not host or not port):
        return current_nodes
    try:
        # Collect a plain-text label for every node (nodes may be dicts or bare values).
        labels = []
        for node in current_nodes:
            labels.append(node['label'] if isinstance(node, dict) else str(node))
        # Optional per-analysis-type hints embedded in the prompt.
        bridge_hint = "(Focus on concepts that bridge different topics)" if analysis_type == 'bridges' else ""
        hub_hint = "(Focus on central hub concepts)" if analysis_type == 'hubs' else ""
        prompt = f"""
You are an expert in semantic analysis and concept mapping. Given the following text and list of extracted concepts,
please select only the 8-12 most semantically important and meaningful concepts that best represent the core ideas.
TEXT CONTENT:
{note_text[:1500]}...
CURRENT CONCEPTS:
{', '.join(labels)}
ANALYSIS TYPE: {analysis_type}
{bridge_hint}
{hub_hint}
TASK:
1. Remove generic words, pronouns, common verbs, and non-meaningful terms
2. Select 8-12 concepts that are most semantically significant
3. Prioritize domain-specific terms, proper nouns, and key concepts
4. Ensure selected concepts represent the main themes of the text
RULES:
- Remove: pronouns (you, his, its, they, etc.), generic words (thing, way, time), common verbs (get, make, have)
- Keep: technical terms, proper nouns, domain-specific concepts, key themes
- Aim for 8-12 final concepts maximum
- Focus on concepts that add semantic value
Respond with ONLY a JSON array of the selected concept labels:
["concept1", "concept2", "concept3", ...]
"""
        # Dispatch to the provider-specific backend (key-based vs. host/port-based).
        provider = ai_provider.lower()
        if provider in ('openai', 'openrouter', 'google', 'groq'):
            key_backends = {
                'openai': _call_openai_api,
                'openrouter': _call_openrouter_api,
                'google': _call_google_api,
                'groq': _call_groq_api,
            }
            selected = await key_backends[provider](prompt, api_key, ai_model)
        elif provider in ('lmstudio', 'ollama'):
            local_backend = _call_lmstudio_api if provider == 'lmstudio' else _call_ollama_api
            selected = await local_backend(prompt, ai_model, host, port)
        else:
            # Unknown provider: leave the graph untouched.
            return current_nodes
        if selected:
            # Keep nodes whose label overlaps (substring, either direction,
            # case-insensitive) with any AI-selected concept.
            kept = []
            for node in current_nodes:
                label = node['label'] if isinstance(node, dict) else str(node)
                lowered = label.lower()
                if any(c.lower() in lowered or lowered in c.lower() for c in selected):
                    kept.append(node)
            # If the AI pruned too aggressively, fall back to the top nodes by importance.
            if len(kept) < 3 and len(current_nodes) >= 3:
                ranked = sorted(current_nodes,
                                key=lambda n: n.get('importance', 0) if isinstance(n, dict) else 0,
                                reverse=True)
                return ranked[:8]
            return kept[:12]  # Limit to 12 nodes max
        return current_nodes
    except Exception as e:
        # Best-effort: any failure (network, parsing, bad node shape) leaves
        # the original graph intact rather than crashing the caller.
        print(f"AI reprocessing error: {str(e)}")
        return current_nodes
async def _call_openai_api(prompt, api_key, model=None):
    """Call OpenAI API for concept filtering."""
    endpoint = "https://api.openai.com/v1/chat/completions"
    payload = {
        "model": model or "gpt-4o-mini",  # inexpensive default model
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 500,
        "temperature": 0.2,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=auth_headers, json=payload) as resp:
            if resp.status != 200:
                # Any non-OK status yields an empty concept list.
                return []
            body = await resp.json()
            return _parse_concepts_response(body["choices"][0]["message"]["content"])
async def _call_openrouter_api(prompt, api_key, model=None):
    """Call OpenRouter API for concept filtering."""
    # Default to the free Llama 3.1 route when no model is specified.
    chosen_model = model or "meta-llama/llama-3.1-8b-instruct:free"
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        # OpenRouter uses Referer/Title headers for app attribution.
        "HTTP-Referer": "https://whispad.local",
        "X-Title": "WhisPad AI",
    }
    body = {
        "model": chosen_model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 500,
        "temperature": 0.2,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post("https://openrouter.ai/api/v1/chat/completions",
                                headers=request_headers, json=body) as resp:
            if resp.status == 200:
                data = await resp.json()
                return _parse_concepts_response(data["choices"][0]["message"]["content"])
    # Non-200 responses produce an empty concept list.
    return []
async def _call_google_api(prompt, api_key, model=None):
    """Call Google Gemini API for concept filtering."""
    model_name = model or "gemini-2.0-flash"
    # Gemini authenticates via a key query parameter rather than a header.
    endpoint = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={api_key}"
    request = {
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": {"temperature": 0.2, "maxOutputTokens": 500},
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers={"Content-Type": "application/json"},
                                json=request) as resp:
            if resp.status != 200:
                return []
            payload = await resp.json()
            # Gemini nests the reply text under candidates -> content -> parts.
            text = payload["candidates"][0]["content"]["parts"][0]["text"]
            return _parse_concepts_response(text)
async def _call_groq_api(prompt, api_key, model=None):
    """Call Groq API for concept filtering."""
    payload = {
        "model": model or "llama-3.1-70b-versatile",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 500,
        "temperature": 0.2,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    # Groq exposes an OpenAI-compatible chat-completions endpoint.
    async with aiohttp.ClientSession() as session:
        async with session.post("https://api.groq.com/openai/v1/chat/completions",
                                headers=headers, json=payload) as resp:
            if resp.status == 200:
                reply = await resp.json()
                return _parse_concepts_response(reply["choices"][0]["message"]["content"])
    return []
async def _call_lmstudio_api(prompt, model, host, port):
    """Call LM Studio API for concept filtering."""
    # LM Studio serves an OpenAI-compatible API on the configured host/port.
    body = {
        "model": model if model else "local-model",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 500,
        "temperature": 0.2,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"http://{host}:{port}/v1/chat/completions",
                                headers={"Content-Type": "application/json"},
                                json=body) as resp:
            if resp.status != 200:
                return []
            answer = await resp.json()
            return _parse_concepts_response(answer["choices"][0]["message"]["content"])
async def _call_ollama_api(prompt, model, host, port):
    """Call Ollama API for concept filtering."""
    payload = {
        "model": model or "llama3",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,  # request the complete reply in a single response
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"http://{host}:{port}/api/chat",
                                headers={"Content-Type": "application/json"},
                                json=payload) as resp:
            if resp.status == 200:
                data = await resp.json()
                # Ollama's chat reply puts the text under message.content.
                return _parse_concepts_response(data["message"]["content"])
    return []
def _parse_concepts_response(content):
"""Parse AI response and extract concept list."""
try:
# Try to find JSON array in the response
import re
# Look for JSON array pattern
json_match = re.search(r'\[.*?\]', content, re.DOTALL)
if json_match:
json_str = json_match.group(0)
concepts = json.loads(json_str)
if isinstance(concepts, list):
return [str(concept).strip() for concept in concepts if concept]
# Fallback: try to parse as JSON directly
concepts = json.loads(content)
if isinstance(concepts, list):
return [str(concept).strip() for concept in concepts if concept]
except (json.JSONDecodeError, AttributeError):
# Fallback: extract concepts from text manually
lines = content.strip().split('\n')
concepts = []
for line in lines:
line = line.strip()
if line.startswith('"') and line.endswith('"'):
concepts.append(line[1:-1])
elif line.startswith('- '):
concepts.append(line[2:])
return concepts[:12]
return []
def build_graph_with_selected_nodes(note_text, selected_concepts, analysis_type='bridges'):
    """Build a new concept graph using only the AI-selected concepts."""
    # No selection: return the unfiltered graph.
    if not selected_concepts:
        return build_graph(note_text, analysis_type, max_terms=60)
    full_result = build_graph(note_text, analysis_type, max_terms=60)
    if not full_result or 'nodes' not in full_result:
        return full_result
    # Keep nodes whose label overlaps (substring, either direction,
    # case-insensitive) with any selected concept.
    wanted = [concept.lower() for concept in selected_concepts]
    kept_nodes = []
    for node in full_result['nodes']:
        label = node.get('label', '').lower()
        if any(w in label or label in w for w in wanted):
            kept_nodes.append(node)
    # Keep only edges whose both endpoints survived the node filter.
    kept_ids = {node['id'] for node in kept_nodes}
    kept_links = [link for link in full_result.get('links', [])
                  if link['source'] in kept_ids and link['target'] in kept_ids]
    result = full_result.copy()
    # NOTE(review): filtered data is written under 'graph' while the source
    # nodes/links were read from the top level of full_result — confirm that
    # consumers expect this asymmetric shape.
    result['graph'] = {
        'nodes': kept_nodes,
        'links': kept_links,
    }
    # Recompute summary insights for the reduced graph.
    if 'insights' in result:
        node_count = len(kept_nodes)
        result['insights']['total_nodes'] = node_count
        result['insights']['total_edges'] = len(kept_links)
        if node_count > 1:
            # Density relative to the maximum possible undirected edge count.
            result['insights']['density'] = len(kept_links) / (node_count * (node_count - 1) / 2)
        else:
            result['insights']['density'] = 0
    return result