|
1 | | -import threading |
| 1 | +import sys |
2 | 2 | from contextlib import contextmanager |
3 | 3 | from typing import Any, Callable, Dict, TypeVar, cast |
4 | 4 |
|
5 | | -_scopes_local = threading.local() |
| 5 | +# Use contextvars if available (Python 3.7+), otherwise fall back to threading.local |
| 6 | +if sys.version_info >= (3, 7): |
| 7 | + import contextvars |
| 8 | + _context_stack: contextvars.ContextVar[list] = contextvars.ContextVar( |
| 9 | + 'posthog_context_stack', |
| 10 | + default=[{}] |
| 11 | + ) |
| 12 | + _use_contextvars = True |
| 13 | +else: |
| 14 | + import threading |
| 15 | + _scopes_local = threading.local() |
| 16 | + _use_contextvars = False |
6 | 17 |
|
7 | 18 |
|
8 | 19 | def _init_guard() -> None: |
9 | | - if not hasattr(_scopes_local, "context_stack"): |
| 20 | + if not _use_contextvars and not hasattr(_scopes_local, "context_stack"): |
10 | 21 | _scopes_local.context_stack = [{}] |
11 | 22 |
|
12 | 23 |
|
13 | 24 | def _get_current_context() -> Dict[str, Any]: |
14 | | - _init_guard() |
15 | | - return _scopes_local.context_stack[-1] |
| 25 | + if _use_contextvars: |
| 26 | + return _context_stack.get()[-1] |
| 27 | + else: |
| 28 | + _init_guard() |
| 29 | + return _scopes_local.context_stack[-1] |
16 | 30 |
|
17 | 31 |
|
18 | 32 | @contextmanager |
@@ -40,13 +54,22 @@ def new_context(): |
40 | 54 | posthog.capture_exception(e) |
41 | 55 | raise e |
42 | 56 | """ |
43 | | - _init_guard() |
44 | | - _scopes_local.context_stack.append({}) |
45 | | - try: |
46 | | - yield |
47 | | - finally: |
48 | | - if len(_scopes_local.context_stack) > 1: |
49 | | - _scopes_local.context_stack.pop() |
| 57 | + if _use_contextvars: |
| 58 | + current_stack = _context_stack.get() |
| 59 | + new_stack = current_stack + [{}] |
| 60 | + token = _context_stack.set(new_stack) |
| 61 | + try: |
| 62 | + yield |
| 63 | + finally: |
| 64 | + _context_stack.reset(token) |
| 65 | + else: |
| 66 | + _init_guard() |
| 67 | + _scopes_local.context_stack.append({}) |
| 68 | + try: |
| 69 | + yield |
| 70 | + finally: |
| 71 | + if len(_scopes_local.context_stack) > 1: |
| 72 | + _scopes_local.context_stack.pop() |
50 | 73 |
|
51 | 74 |
|
52 | 75 | def tag(key: str, value: Any) -> None: |
|
0 commit comments