Skip to content

Commit ba9ec6d

Browse files
committed
roll back to json for llm
1 parent fae8622 commit ba9ec6d

File tree

1 file changed

+54
-67
lines changed

1 file changed

+54
-67
lines changed

psyflow/LLM.py

Lines changed: 54 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from importlib import resources
1010
import yaml
1111
from psyflow import load_config
12-
1312
# --- Custom Exception for LLM API Errors ---
1413
class LLMAPIError(Exception):
1514
"""
@@ -246,40 +245,28 @@ def _filter_openai_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
246245
valid = {"temperature","max_tokens","top_p","stop","presence_penalty","frequency_penalty","n","logit_bias","stream"}
247246
return {k: v for k, v in params.items() if k in valid}
248247

249-
@staticmethod
250-
def _fence_content(path: str, content: str) -> str:
251-
"""
252-
Wrap content in a markdown fence based on file extension.
253-
"""
254-
ext = os.path.splitext(path)[1].lower()
255-
if ext in ('.py', '.js', '.ts'):
256-
lang = 'python'
257-
elif ext in ('.yaml', '.yml'):
258-
lang = 'yaml'
259-
elif ext in ('.md', '.markdown'):
260-
# already markdown—no fence
261-
return content
262-
else:
263-
lang = ''
264-
if lang:
265-
return f"```{lang}\n{content}\n```"
266-
return content
267248

249+
@staticmethod
268250
def _parse_entry(
269-
self,
270251
entry: Dict[str, Union[str, List[str]]]
271252
) -> Dict[str, str]:
272253
"""
273-
Parse one dict of {key → file(s)/URL(s)/text} into {key → fenced markdown/text}.
254+
Parse one dict of {key → file(s)/URL(s)/text} into {key → combined text}.
255+
256+
:param entry:
257+
A mapping where each value is either:
258+
- a list of local file paths or HTTP URLs
259+
- a raw text string
260+
:return:
261+
A dict mapping each key to the concatenated text contents.
274262
"""
275263
out: Dict[str, str] = {}
276-
277264
def _load(loc: str) -> Optional[str]:
278-
# Local file
265+
# Local file?
279266
if os.path.isfile(loc):
280267
with open(loc, 'r', encoding='utf-8') as f:
281268
return f.read()
282-
# URL
269+
# URL?
283270
parsed = urlparse(loc)
284271
if parsed.scheme in ("http", "https"):
285272
resp = requests.get(loc, timeout=10)
@@ -289,76 +276,76 @@ def _load(loc: str) -> Optional[str]:
289276

290277
for key, val in entry.items():
291278
if isinstance(val, list):
292-
parts: List[str] = []
279+
chunks = []
293280
for loc in val:
294-
txt = _load(loc) or ''
281+
txt = _load(loc)
295282
if txt:
296-
parts.append(self._fence_content(loc, txt))
297-
if parts:
298-
out[key] = "\n\n".join(parts)
283+
chunks.append(txt)
284+
if chunks:
285+
out[key] = "\n\n".join(chunks)
299286

300287
elif isinstance(val, str):
301288
if val.startswith(("http://", "https://")):
302-
txt = _load(val) or ''
289+
txt = _load(val)
303290
if txt:
304291
out[key] = txt
305-
elif os.path.isfile(val):
306-
txt = _load(val) or ''
307-
if txt:
308-
out[key] = self._fence_content(val, txt)
309292
else:
310-
# raw text
311293
out[key] = val.strip()
312294

313295
return out
314-
296+
297+
315298
def add_knowledge(
316-
self,
317-
source: Union[
318-
str, # path to JSON or MD file
319-
List[Dict[str, Union[str, List[str]]]] # in-memory entries
320-
]
321-
) -> None:
322-
"""
323-
Bulk-load few-shot examples:
324-
• If `source` is a .json -> load list of dicts (no parsing)
325-
• If `source` is a .md -> load single entry under key 'markdown'
326-
• If `source` is a list -> parse each via _parse_entry()
299+
self,
300+
source: Union[
301+
str, # path to JSON file
302+
List[Dict[str, Union[str, List[str]]]] # in-memory entries
303+
]
304+
) -> None:
305+
"""
306+
Bulk-load few-shot examples into memory from either:
307+
308+
1. A JSON file path containing a list of example-dicts, or
309+
2. A list of example-dicts directly.
310+
311+
Each example-dict maps keys to either:
312+
• List[str] of file paths or URLs → will be parsed via `_parse_entry()`
313+
• Raw text (str) → will be stripped and stored as-is
314+
315+
:param source:
316+
- If `str`, treated as path to a JSON file containing List[Dict[...]].
317+
- If `list`, treated as in-memory list of example dicts.
318+
:raises ValueError: on invalid JSON structure or unsupported source type.
327319
"""
328320
if isinstance(source, str):
329-
ext = os.path.splitext(source)[1].lower()
330-
if ext == '.json':
331-
with open(source, 'r', encoding='utf-8') as f:
332-
data = json.load(f)
333-
if not isinstance(data, list):
334-
raise ValueError("JSON must contain a list of example-dicts")
335-
for ex in data:
336-
if isinstance(ex, dict):
337-
self.knowledge_base.append(ex)
338-
elif ext in ('.md', '.markdown'):
339-
# treat entire markdown file as one example
340-
entry = self._parse_entry({'markdown': source})
341-
if entry:
342-
self.knowledge_base.append(entry)
343-
else:
344-
raise ValueError("Unsupported file type for add_knowledge; use .json or .md")
321+
# load from JSON file
322+
with open(source, 'r', encoding='utf-8') as f:
323+
data = json.load(f)
324+
if not isinstance(data, list):
325+
raise ValueError("Expected a JSON file containing a list of examples")
326+
for ex in data:
327+
if isinstance(ex, dict):
328+
# assume already parsed JSON examples
329+
self.knowledge_base.append(ex)
345330
elif isinstance(source, list):
331+
# parse each entry (files/URLs/raw text) into text blobs
346332
for ex in source:
347333
if not isinstance(ex, dict):
348-
raise ValueError("Each item in list must be a dict")
334+
continue
349335
parsed = self._parse_entry(ex)
350336
if parsed:
351337
self.knowledge_base.append(parsed)
352338
else:
353-
raise ValueError("add_knowledge requires a JSON path, MD path, or list of dicts")
339+
raise ValueError(
340+
"add_knowledge() requires a JSON file path or a list of example-dicts"
341+
)
354342

355343
def save_knowledge(self, json_path: str) -> None:
356344
"""
357-
Write current knowledge_base (list of dicts) to a JSON file.
345+
Write current knowledge_base (a list of dicts) to a JSON file.
358346
"""
359347
with open(json_path, 'w', encoding='utf-8') as f:
360348
json.dump(self.knowledge_base, f, indent=2, ensure_ascii=False)
361-
362349

363350
@staticmethod
364351
def _strip_code_fences(text: str) -> str:

0 commit comments

Comments
 (0)