Skip to content

Commit 3256ea9

Browse files
ellisonbgdlqqq
andauthored
Add error handling for persona loading failures (#1397)
* Add ychat based error handling on persona loading. * Precommit fixes. * fix mypy & unit tests --------- Co-authored-by: David L. Qiu <[email protected]>
1 parent 52de802 commit 3256ea9

File tree

2 files changed

+103
-57
lines changed

2 files changed

+103
-57
lines changed

packages/jupyter-ai/jupyter_ai/personas/persona_manager.py

Lines changed: 93 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import inspect
66
import os
77
import sys
8+
import traceback
89
from glob import glob
910
from logging import Logger
1011
from pathlib import Path
1112
from time import time_ns
1213
from typing import TYPE_CHECKING, Any
1314

1415
from importlib_metadata import entry_points
15-
from jupyterlab_chat.models import Message
16+
from jupyterlab_chat.models import Message, NewMessage
1617
from jupyterlab_chat.ychat import YChat
1718
from traitlets.config import LoggingConfigurable
1819

@@ -52,11 +53,10 @@ class PersonaManager(LoggingConfigurable):
5253
type for type checkers.
5354
"""
5455

55-
# TODO: the Persona classes from entry points should be stored as a class
56-
# attribute, since they will not change at runtime.
57-
# That should be injected into this instance attribute when personas defined
58-
# under `.jupyter` are loaded.
59-
_persona_classes: list[type[BasePersona]] | None = None
56+
# We treat this as a class attribute so that we only have to load them once
57+
_ep_persona_classes: list[dict] | None = None
58+
# Local persona classes are instance attributes to support frequent reloading
59+
_local_persona_classes: list[dict] | None = None
6060
_personas: dict[str, BasePersona]
6161
file_id: str
6262

@@ -89,26 +89,22 @@ def __init__(
8989

9090
# Initialize MCP config loader
9191
self._mcp_config_loader = MCPConfigLoader()
92-
93-
# Load persona classes from entry points.
94-
# This is stored in a class attribute (global to all instances) because
95-
# the entry points are immutable after the server starts, so they only
96-
# need to be loaded once.
97-
if not isinstance(self._persona_classes, list):
98-
self._init_persona_classes()
99-
assert isinstance(self._persona_classes, list)
100-
92+
self._init_persona_classes()
93+
self.log.info("Persona classes loaded!")
10194
self._personas = self._init_personas()
95+
self.log.info("Personas created fully!")
10296

10397
def _init_persona_classes(self) -> None:
98+
"""Read entry-point and local persona classes."""
99+
if PersonaManager._ep_persona_classes is None:
100+
self._init_ep_persona_classes()
101+
assert isinstance(PersonaManager._ep_persona_classes, list)
102+
self._init_local_persona_classes()
103+
104+
def _init_ep_persona_classes(self) -> None:
104105
"""
105106
Initializes the list of persona *classes* by retrieving the
106107
`jupyter-ai.personas` entry points group.
107-
108-
# TODO: fix this part of docs now that we have it as an instance attr.
109-
This list is cached in the `self._persona_classes` instance
110-
attribute, .e. this method should only run once in the extension
111-
lifecycle.
112108
"""
113109
# Loading is in two parts:
114110
# 1. Load persona classes from package entry points.
@@ -122,14 +118,20 @@ def _init_persona_classes(self) -> None:
122118
self.log.info(f"Found {len(persona_eps)} entry points under '{EPG_NAME}'.")
123119
self.log.info("PENDING: Loading AI persona classes from entry points...")
124120
start_time_ns = time_ns()
125-
persona_classes: list[type[BasePersona]] = []
121+
persona_classes: list[dict] = []
126122

127123
for persona_ep in persona_eps:
128124
try:
129125
# Load a persona class from each entry point
130126
persona_class = persona_ep.load()
131127
assert issubclass(persona_class, BasePersona)
132-
persona_classes.append(persona_class)
128+
persona_classes.append(
129+
{
130+
"module": persona_ep.name,
131+
"persona_class": persona_class,
132+
"traceback": None,
133+
}
134+
)
133135
class_module, class_name = persona_ep.value.split(":")
134136
self.log.info(
135137
f" - Loaded AI persona class '{class_name}' from '{class_module}' using entry point '{persona_ep.name}'."
@@ -138,13 +140,21 @@ def _init_persona_classes(self) -> None:
138140
# On exception, log an error and continue.
139141
# This does not stop the surrounding `for` loop. If a persona
140142
# fails to load, it should not halt other personas from loading.
143+
tb_str = traceback.format_exc()
141144
self.log.exception(
142-
f" - Unable to load AI persona from entry point `{persona_ep.name}` due to an exception printed below."
145+
f" - Unable to load AI persona from entry point `{persona_ep.name}` due to an exception printed below.\n{tb_str}"
146+
)
147+
persona_classes.append(
148+
{
149+
"module": persona_ep.name,
150+
"persona_class": None,
151+
"traceback": tb_str,
152+
}
143153
)
144154
continue
145155

146156
if len(persona_classes) > 0:
147-
elapsed_time_ms = (time_ns() - start_time_ns) // 1000
157+
elapsed_time_ms = (time_ns() - start_time_ns) // 1000000
148158
self.log.info(
149159
f"SUCCESS: Loaded {len(persona_classes)} AI persona classes from entry points. Time elapsed: {elapsed_time_ms}ms."
150160
)
@@ -154,22 +164,27 @@ def _init_persona_classes(self) -> None:
154164
+ "Please verify your server configuration and open a new issue on our GitHub repo if this warning persists."
155165
)
156166

157-
# Load persona classes from local filesystem
167+
PersonaManager._ep_persona_classes = persona_classes
168+
169+
def _init_local_persona_classes(self) -> None:
170+
"""Load persona classes from local filesystem."""
158171
dotjupyter_dir = self.get_dotjupyter_dir()
159172
if dotjupyter_dir is None:
160173
self.log.info("No .jupyter directory found for loading local personas.")
161174
else:
162-
persona_classes.extend(load_from_dir(dotjupyter_dir, self.log))
163-
164-
self._persona_classes = persona_classes
175+
self._local_persona_classes = load_from_dir(dotjupyter_dir, self.log)
165176

166177
def _init_personas(self) -> dict[str, BasePersona]:
167178
"""
168179
Initializes the list of persona instances for the YChat instance passed
169180
to the constructor.
170181
"""
171182
# Ensure that persona classes were initialized first
172-
persona_classes = self._persona_classes
183+
persona_classes = []
184+
if isinstance(PersonaManager._ep_persona_classes, list):
185+
persona_classes.extend(PersonaManager._ep_persona_classes)
186+
if isinstance(self._local_persona_classes, list):
187+
persona_classes.extend(self._local_persona_classes)
173188
assert isinstance(persona_classes, list)
174189

175190
# If no persona classes are available, log a warning and return
@@ -183,7 +198,13 @@ def _init_personas(self) -> dict[str, BasePersona]:
183198
start_time_ns = time_ns()
184199

185200
personas: dict[str, BasePersona] = {}
186-
for Persona in persona_classes:
201+
for item in persona_classes:
202+
item.get("module")
203+
Persona = item.get("persona_class")
204+
tb = item.get("traceback")
205+
if Persona is None or tb is not None:
206+
self._display_persona_error_message(item)
207+
continue
187208
try:
188209
persona = Persona(
189210
parent=self,
@@ -192,10 +213,18 @@ def _init_personas(self) -> dict[str, BasePersona]:
192213
message_interrupted=self.message_interrupted,
193214
)
194215
except Exception:
216+
tb_str = traceback.format_exc()
195217
self.log.exception(
196218
f"The persona provided by `{Persona.__module__}` "
197-
"raised an exception while initializing, "
198-
"printed below."
219+
f"raised an exception while instantiating, "
220+
f"printed below.\n {tb_str}"
221+
)
222+
self._display_persona_error_message(
223+
{
224+
"module": Persona.__module__,
225+
"persona_class": Persona,
226+
"traceback": tb_str,
227+
}
199228
)
200229
continue
201230

@@ -218,6 +247,13 @@ def _init_personas(self) -> dict[str, BasePersona]:
218247
)
219248
return personas
220249

250+
def _display_persona_error_message(self, persona_item: dict) -> None:
251+
tb = persona_item.get("traceback")
252+
if tb is None:
253+
return
254+
body = f"Loading an AI persona raised an exception:\n\n```python\n{tb}```"
255+
self.ychat.add_message(NewMessage(body=body, sender="PersonaManager"))
256+
221257
@property
222258
def personas(self) -> dict[str, BasePersona]:
223259
"""
@@ -310,35 +346,35 @@ def get_mcp_config(self) -> dict[str, Any]:
310346
return self._mcp_config_loader.get_config(jdir)
311347

312348

313-
def load_from_dir(root_dir: str, log: Logger) -> list[type[BasePersona]]:
349+
def load_from_dir(dir: str, log: Logger) -> list[dict]:
314350
"""
315351
Load _persona class declarations_ from Python files in the local filesystem.
316352
317353
Those class declarations are then used to instantiate personas by the
318354
`PersonaManager`.
319355
320-
Scans the root_dir for .py files containing `persona` in their name that do
356+
Scans the dir for .py files containing `persona` in their name that do
321357
_not_ start with a single `_` (i.e. private modules are skipped). Then, it
322358
dynamically imports them, and extracts any class declarations that are
323359
subclasses of `BasePersona`.
324360
325-
Args:
326-
root_dir: Directory to scan for persona Python files.
327-
log: Logger instance for logging messages.
361+
Args:
362+
dir: Directory to scan for persona Python files.
363+
log: Logger instance for logging messages.
328364
329365
Returns:
330-
List of `BasePersona` subclasses found in the directory.
366+
List of `BasePersona` subclasses found in the directory.
331367
"""
332-
persona_classes: list[type[BasePersona]] = []
368+
persona_classes: list[dict] = []
333369

334-
log.info(f"Searching for persona files in {root_dir}")
370+
log.info(f"Searching for persona files in {dir}")
335371
# Check if root directory exists
336-
if not os.path.exists(root_dir):
372+
if not os.path.exists(dir):
337373
return persona_classes
338374

339375
# Find all .py files in the root directory that contain "persona" in the name
340376
try:
341-
all_py_files = glob(os.path.join(root_dir, "*.py"))
377+
all_py_files = glob(os.path.join(dir, "*.py"))
342378
py_files = []
343379
for f in all_py_files:
344380
fname_lower = Path(f).stem.lower()
@@ -350,17 +386,17 @@ def load_from_dir(root_dir: str, log: Logger) -> list[type[BasePersona]]:
350386
except Exception as e:
351387
# On exception with glob operation, return empty list
352388
log.error(
353-
f"{type(e).__name__} occurred while searching for Python files in {root_dir}"
389+
f"{type(e).__name__} occurred while searching for Python files in {dir}"
354390
)
355391
return persona_classes
356392

357393
if py_files:
358-
log.info(f"Found files from {root_dir}: {[Path(f).name for f in py_files]}")
394+
log.info(f"Found files from {dir}: {[Path(f).name for f in py_files]}")
359395

360396
# Temporarily add root_dir to sys.path for imports
361-
root_dir_in_path = root_dir in sys.path
362-
if not root_dir_in_path:
363-
sys.path.insert(0, root_dir)
397+
dir_in_path = dir in sys.path
398+
if not dir_in_path:
399+
sys.path.insert(0, dir)
364400

365401
try:
366402
# For each .py file, dynamically import the module and extract all
@@ -387,17 +423,23 @@ def load_from_dir(root_dir: str, log: Logger) -> list[type[BasePersona]]:
387423
and obj.__module__ == module_name
388424
):
389425
log.info(f"Found persona class '{obj.__name__}' in '{py_file}'")
390-
persona_classes.append(obj)
426+
persona_classes.append(
427+
{"module": py_file, "persona_class": obj, "traceback": None}
428+
)
391429

392430
except Exception:
393431
# On exception, log error and continue to next file
432+
tb_str = traceback.format_exc()
394433
log.exception(
395-
f"Unable to load persona classes from '{py_file}', exception details printed below."
434+
f"Unable to load persona classes from '{py_file}', exception details printed below.\n{tb_str}"
435+
)
436+
persona_classes.append(
437+
{"module": py_file, "persona_class": None, "traceback": tb_str}
396438
)
397439
continue
398440
finally:
399441
# Remove root_dir from sys.path if we added it
400-
if not root_dir_in_path and root_dir in sys.path:
401-
sys.path.remove(root_dir)
442+
if not dir_in_path and dir in sys.path:
443+
sys.path.remove(dir)
402444

403445
return persona_classes

packages/jupyter-ai/jupyter_ai/tests/test_personas.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,19 @@ def process_message(self, message):
6363
result = load_from_dir(str(tmp_persona_dir), mock_logger)
6464

6565
assert len(result) == 1
66-
assert result[0].__name__ == "TestPersona"
67-
assert issubclass(result[0], BasePersona)
66+
assert result[0]["persona_class"].__name__ == "TestPersona"
67+
assert issubclass(result[0]["persona_class"], BasePersona)
68+
assert result[0]["traceback"] is None
6869

69-
def test_bad_persona_file_returns_empty_list(self, tmp_persona_dir, mock_logger):
70-
"""Test that a file with syntax errors returns empty list."""
70+
def test_bad_persona_file_returns_error_entry(self, tmp_persona_dir, mock_logger):
71+
"""Test that a file with syntax errors returns an error entry."""
7172
# Create a file with invalid Python code
7273
bad_persona_file = tmp_persona_dir / "bad_persona.py"
73-
bad_persona_file.write_text("1/0")
74+
bad_persona_file.write_text("1/0 # This will cause a syntax error")
7475

7576
result = load_from_dir(str(tmp_persona_dir), mock_logger)
7677

77-
assert result == []
78+
assert len(result) == 1
79+
assert result[0]["persona_class"] is None
80+
assert result[0]["traceback"] is not None
81+
assert "ZeroDivisionError" in result[0]["traceback"]

0 commit comments

Comments
 (0)