Skip to content

Commit d143126

Browse files
bkrabachrampartemicrosoft-amplifier
authored
fix: catch BaseException in cleanup loops to survive asyncio.CancelledError (#12)
asyncio.CancelledError is a BaseException subclass (Python 3.9+), so `except Exception` in kernel loop sites lets it escape, aborting remaining cleanup functions, hook handlers, and contribution collection. Different patterns per site semantics: - collect_contributions: catch CancelledError → break (partial results) - cleanup loops: catch BaseException → continue, re-raise fatal after - session.cleanup: try/finally ensures loader.cleanup always runs - session.execute: catch BaseException for status tracking (re-raises) - hooks emit/emit_and_collect: catch CancelledError → log, continue - cancellation.trigger_callbacks: catch CancelledError → continue Includes 10 new tests for CancelledError resilience across all sites. Inspired-by: ramparte (PR #11) Co-Authored-By: ramparte <ramparte@users.noreply.github.com> Co-Authored-By: Amplifier <240397093+microsoft-amplifier@users.noreply.github.com>
1 parent f57a59a commit d143126

File tree

5 files changed

+421
-25
lines changed

5 files changed

+421
-25
lines changed

amplifier_core/cancellation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
The app layer provides the POLICY (when to cancel).
66
"""
77

8+
import asyncio
9+
import logging
810
from dataclasses import dataclass, field
911
from enum import Enum
1012
from typing import TYPE_CHECKING, Awaitable, Callable, Set
@@ -161,8 +163,22 @@ def on_cancel(self, callback: Callable[[], Awaitable[None]]) -> None:
161163

162164
async def trigger_callbacks(self) -> None:
163165
"""Trigger all registered cancellation callbacks."""
166+
_logger = logging.getLogger(__name__)
167+
first_fatal = None
164168
for callback in self._on_cancel_callbacks:
165169
try:
166170
await callback()
171+
except asyncio.CancelledError:
172+
# CancelledError is a BaseException (Python 3.9+). Log and continue
173+
# so all cancellation callbacks run.
174+
_logger.warning("CancelledError in cancellation callback")
167175
except Exception:
168176
pass # Don't let callback errors prevent cancellation
177+
except BaseException as e:
178+
# Track fatal exceptions (KeyboardInterrupt, SystemExit) for re-raise
179+
# after all callbacks complete.
180+
_logger.warning(f"Fatal exception in cancellation callback: {e}")
181+
if first_fatal is None:
182+
first_fatal = e
183+
if first_fatal is not None:
184+
raise first_fatal

amplifier_core/coordinator.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
identifiers and basic state necessary to make module boundaries work.
1313
"""
1414

15+
import asyncio
1516
import inspect
1617
import logging
1718
from collections.abc import Awaitable
@@ -335,6 +336,14 @@ async def collect_contributions(self, channel: str) -> list[Any]:
335336

336337
if result is not None:
337338
contributions.append(result)
339+
except asyncio.CancelledError:
340+
# CancelledError is a BaseException (Python 3.9+) - catch specifically.
341+
# Stop collecting (honor cancellation signal) and return what we have.
342+
logger.warning(
343+
f"Collection cancelled during contributor "
344+
f"'{contributor['name']}' on channel '{channel}'"
345+
)
346+
break
338347
except Exception as e:
339348
logger.warning(
340349
f"Contributor '{contributor['name']}' on channel '{channel}' failed: {e}"
@@ -344,6 +353,7 @@ async def collect_contributions(self, channel: str) -> list[Any]:
344353

345354
async def cleanup(self):
346355
"""Call all registered cleanup functions."""
356+
first_fatal = None
347357
for cleanup_fn in reversed(self._cleanup_functions):
348358
try:
349359
if callable(cleanup_fn):
@@ -353,8 +363,16 @@ async def cleanup(self):
353363
result = cleanup_fn()
354364
if inspect.iscoroutine(result):
355365
await result
356-
except Exception as e:
366+
except BaseException as e:
367+
# Catch BaseException to survive asyncio.CancelledError (a BaseException
368+
# subclass since Python 3.9) so remaining cleanup functions still run.
369+
# Track fatal exceptions (KeyboardInterrupt, SystemExit) for re-raise
370+
# after all cleanup completes.
357371
logger.error(f"Error during cleanup: {e}")
372+
if first_fatal is None and not isinstance(e, Exception):
373+
first_fatal = e
374+
if first_fatal is not None:
375+
raise first_fatal
358376

359377
def reset_turn(self):
360378
"""Reset per-turn tracking. Call at turn boundaries."""

amplifier_core/hooks.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,24 @@ def register(
7070
Returns:
7171
Unregister function
7272
"""
73-
hook_handler = HookHandler(handler=handler, priority=priority, name=name or handler.__name__)
73+
hook_handler = HookHandler(
74+
handler=handler, priority=priority, name=name or handler.__name__
75+
)
7476

7577
self._handlers[event].append(hook_handler)
7678
self._handlers[event].sort() # Keep sorted by priority
7779

78-
logger.debug(f"Registered hook '{hook_handler.name}' for event '{event}' with priority {priority}")
80+
logger.debug(
81+
f"Registered hook '{hook_handler.name}' for event '{event}' with priority {priority}"
82+
)
7983

8084
def unregister():
8185
"""Remove this handler from the registry."""
8286
if hook_handler in self._handlers[event]:
8387
self._handlers[event].remove(hook_handler)
84-
logger.debug(f"Unregistered hook '{hook_handler.name}' from event '{event}'")
88+
logger.debug(
89+
f"Unregistered hook '{hook_handler.name}' from event '{event}'"
90+
)
8591

8692
return unregister
8793

@@ -140,11 +146,15 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult:
140146
result = await hook_handler.handler(event, current_data)
141147

142148
if not isinstance(result, HookResult):
143-
logger.warning(f"Handler '{hook_handler.name}' returned invalid result type")
149+
logger.warning(
150+
f"Handler '{hook_handler.name}' returned invalid result type"
151+
)
144152
continue
145153

146154
if result.action == "deny":
147-
logger.info(f"Event '{event}' denied by handler '{hook_handler.name}': {result.reason}")
155+
logger.info(
156+
f"Event '{event}' denied by handler '{hook_handler.name}': {result.reason}"
157+
)
148158
return result
149159

150160
if result.action == "modify" and result.data is not None:
@@ -154,15 +164,27 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult:
154164
# Collect inject_context actions for merging
155165
if result.action == "inject_context" and result.context_injection:
156166
inject_context_results.append(result)
157-
logger.debug(f"Handler '{hook_handler.name}' returned inject_context")
167+
logger.debug(
168+
f"Handler '{hook_handler.name}' returned inject_context"
169+
)
158170

159171
# Preserve ask_user (only first one, can't merge approvals)
160172
if result.action == "ask_user" and special_result is None:
161173
special_result = result
162174
logger.debug(f"Handler '{hook_handler.name}' returned ask_user")
163175

176+
except asyncio.CancelledError:
177+
# CancelledError is a BaseException (Python 3.9+). Log and continue
178+
# so all handlers observe the event (important for cleanup events
179+
# like session:end that flow through emit).
180+
logger.error(
181+
f"CancelledError in hook handler '{hook_handler.name}' "
182+
f"for event '{event}'"
183+
)
164184
except Exception as e:
165-
logger.error(f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}")
185+
logger.error(
186+
f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}"
187+
)
166188
# Continue with other handlers even if one fails
167189

168190
# If multiple inject_context results, merge them.
@@ -173,7 +195,9 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult:
173195
merged_inject = self._merge_inject_context_results(inject_context_results)
174196
if special_result is None:
175197
special_result = merged_inject
176-
logger.debug(f"Merged {len(inject_context_results)} inject_context results")
198+
logger.debug(
199+
f"Merged {len(inject_context_results)} inject_context results"
200+
)
177201
else:
178202
# ask_user already captured - don't overwrite it
179203
logger.debug(
@@ -208,7 +232,9 @@ def _merge_inject_context_results(self, results: list[HookResult]) -> HookResult
208232
return results[0]
209233

210234
# Combine all injections
211-
combined_content = "\n\n".join(result.context_injection for result in results if result.context_injection)
235+
combined_content = "\n\n".join(
236+
result.context_injection for result in results if result.context_injection
237+
)
212238

213239
# Use settings from first result (role, ephemeral, suppress_output)
214240
first = results[0]
@@ -221,7 +247,9 @@ def _merge_inject_context_results(self, results: list[HookResult]) -> HookResult
221247
suppress_output=first.suppress_output,
222248
)
223249

224-
async def emit_and_collect(self, event: str, data: dict[str, Any], timeout: float = 1.0) -> list[Any]:
250+
async def emit_and_collect(
251+
self, event: str, data: dict[str, Any], timeout: float = 1.0
252+
) -> list[Any]:
225253
"""
226254
Emit event and collect data from all handler responses.
227255
@@ -247,27 +275,46 @@ async def emit_and_collect(self, event: str, data: dict[str, Any], timeout: floa
247275
logger.debug(f"No handlers for event '{event}'")
248276
return []
249277

250-
logger.debug(f"Collecting responses for event '{event}' from {len(handlers)} handlers")
278+
logger.debug(
279+
f"Collecting responses for event '{event}' from {len(handlers)} handlers"
280+
)
251281

252282
responses = []
253283
for hook_handler in handlers:
254284
try:
255285
# Call handler with timeout
256-
result = await asyncio.wait_for(hook_handler.handler(event, data), timeout=timeout)
286+
result = await asyncio.wait_for(
287+
hook_handler.handler(event, data), timeout=timeout
288+
)
257289

258290
if not isinstance(result, HookResult):
259-
logger.warning(f"Handler '{hook_handler.name}' returned invalid result type")
291+
logger.warning(
292+
f"Handler '{hook_handler.name}' returned invalid result type"
293+
)
260294
continue
261295

262296
# Collect response data if present
263297
if result.data is not None:
264298
responses.append(result.data)
265-
logger.debug(f"Collected response from handler '{hook_handler.name}'")
299+
logger.debug(
300+
f"Collected response from handler '{hook_handler.name}'"
301+
)
266302

267303
except TimeoutError:
268-
logger.warning(f"Handler '{hook_handler.name}' timed out after {timeout}s")
304+
logger.warning(
305+
f"Handler '{hook_handler.name}' timed out after {timeout}s"
306+
)
307+
except asyncio.CancelledError:
308+
# CancelledError is a BaseException (Python 3.9+). Log and continue
309+
# so all handlers get a chance to respond.
310+
logger.error(
311+
f"CancelledError in hook handler '{hook_handler.name}' "
312+
f"for event '{event}'"
313+
)
269314
except Exception as e:
270-
logger.error(f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}")
315+
logger.error(
316+
f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}"
317+
)
271318
# Continue with other handlers
272319

273320
logger.debug(f"Collected {len(responses)} responses for event '{event}'")
@@ -286,4 +333,7 @@ def list_handlers(self, event: str | None = None) -> dict[str, list[str]]:
286333
if event:
287334
handlers = self._handlers.get(event, [])
288335
return {event: [h.name for h in handlers if h.name is not None]}
289-
return {evt: [h.name for h in handlers if h.name is not None] for evt, handlers in self._handlers.items()}
336+
return {
337+
evt: [h.name for h in handlers if h.name is not None]
338+
for evt, handlers in self._handlers.items()
339+
}

amplifier_core/session.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
logger = logging.getLogger(__name__)
2121

2222

23-
def _safe_exception_str(e: Exception) -> str:
23+
def _safe_exception_str(e: BaseException) -> str:
2424
"""
2525
CRITICAL: Explicitly handle exception string conversion for Windows cp1252 compatibility.
2626
Default encoding can fail on non-cp1252 characters, causing a crash during error handling.
@@ -432,8 +432,9 @@ async def execute(self, prompt: str) -> str:
432432
self.status.status = "completed"
433433
return result
434434

435-
except Exception as e:
436-
# Check if this was a cancellation-related exception
435+
except BaseException as e:
436+
# Catch BaseException to handle asyncio.CancelledError (a BaseException
437+
# subclass since Python 3.9). All paths re-raise after status tracking.
437438
if self.coordinator.cancellation.is_cancelled:
438439
self.status.status = "cancelled"
439440
from .events import CANCEL_COMPLETED
@@ -455,10 +456,13 @@ async def execute(self, prompt: str) -> str:
455456

456457
async def cleanup(self: "AmplifierSession") -> None:
457458
"""Clean up session resources."""
458-
await self.coordinator.cleanup()
459-
# Clean up sys.path modifications
460-
if self.loader:
461-
self.loader.cleanup()
459+
try:
460+
await self.coordinator.cleanup()
461+
finally:
462+
# Clean up sys.path modifications - must always run even if
463+
# coordinator cleanup raises (e.g., asyncio.CancelledError)
464+
if self.loader:
465+
self.loader.cleanup()
462466

463467
async def __aenter__(self: "AmplifierSession"):
464468
"""Async context manager entry."""

0 commit comments

Comments
 (0)