Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/structlog/_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import collections
import contextvars
import sys
import threading

from collections.abc import Callable
from typing import Any
Expand All @@ -28,7 +29,7 @@
NOTSET,
WARNING,
)
from .contextvars import _ASYNC_CALLING_STACK
from .contextvars import _ASYNC_CALLING_STACK, _ASYNC_CALLING_THREAD
from .typing import FilteringBoundLogger


Expand Down Expand Up @@ -60,6 +61,11 @@ async def aexception(
if kw.get("exc_info", True) is True:
kw["exc_info"] = sys.exc_info()

# Capture thread info before passing to executor
thread_id = threading.get_ident()
thread_name = threading.current_thread().name
thread_token = _ASYNC_CALLING_THREAD.set((thread_id, thread_name))

scs_token = _ASYNC_CALLING_STACK.set(sys._getframe().f_back) # type: ignore[arg-type]
ctx = contextvars.copy_context()
try:
Expand All @@ -69,6 +75,7 @@ async def aexception(
)
finally:
_ASYNC_CALLING_STACK.reset(scs_token)
_ASYNC_CALLING_THREAD.reset(thread_token)

return runner

Expand Down Expand Up @@ -173,6 +180,11 @@ async def ameth(self: Any, event: str, *args: Any, **kw: Any) -> Any:
"""
event = _maybe_interpolate(event, args)

# Capture thread info before passing to executor
thread_id = threading.get_ident()
thread_name = threading.current_thread().name
thread_token = _ASYNC_CALLING_THREAD.set((thread_id, thread_name))

scs_token = _ASYNC_CALLING_STACK.set(sys._getframe().f_back) # type: ignore[arg-type]
ctx = contextvars.copy_context()
try:
Expand All @@ -184,6 +196,7 @@ async def ameth(self: Any, event: str, *args: Any, **kw: Any) -> Any:
)
finally:
_ASYNC_CALLING_STACK.reset(scs_token)
_ASYNC_CALLING_THREAD.reset(thread_token)

meth.__name__ = name
ameth.__name__ = f"a{name}"
Expand Down Expand Up @@ -211,6 +224,11 @@ async def alog(
name = LEVEL_TO_NAME[level]
event = _maybe_interpolate(event, args)

# Capture thread info before passing to executor
thread_id = threading.get_ident()
thread_name = threading.current_thread().name
thread_token = _ASYNC_CALLING_THREAD.set((thread_id, thread_name))

scs_token = _ASYNC_CALLING_STACK.set(sys._getframe().f_back) # type: ignore[arg-type]
ctx = contextvars.copy_context()
try:
Expand All @@ -222,6 +240,7 @@ async def alog(
)
finally:
_ASYNC_CALLING_STACK.reset(scs_token)
_ASYNC_CALLING_THREAD.reset(thread_token)
return runner

meths: dict[str, Callable[..., Any]] = {"log": log, "alog": alog}
Expand Down
6 changes: 6 additions & 0 deletions src/structlog/contextvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
contextvars.ContextVar("_ASYNC_CALLING_STACK")
)

# Stores thread info captured at async call time.
# Value is a tuple of (thread_id: int, thread_name: str)
_ASYNC_CALLING_THREAD: contextvars.ContextVar[tuple[int, str]] = (
contextvars.ContextVar("_ASYNC_CALLING_THREAD")
)

# For proper isolation, we have to use a dict of ContextVars instead of a
# single ContextVar with a dict.
# See https://github.com/hynek/structlog/pull/302 for details.
Expand Down
15 changes: 13 additions & 2 deletions src/structlog/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from ._log_levels import NAME_TO_LEVEL, add_log_level
from ._utils import get_processname
from .contextvars import _ASYNC_CALLING_THREAD
from .tracebacks import ExceptionDictTransformer
from .typing import (
EventDict,
Expand Down Expand Up @@ -783,11 +784,21 @@ def _get_callsite_lineno(module: str, frame: FrameType) -> Any:


def _get_callsite_thread(module: str, frame: FrameType) -> Any:
return threading.get_ident()
# Use captured thread info from async calls if available
try:
thread_info = _ASYNC_CALLING_THREAD.get()
return thread_info[0]
except LookupError:
return threading.get_ident()


def _get_callsite_thread_name(module: str, frame: FrameType) -> Any:
return threading.current_thread().name
# Use captured thread info from async calls if available
try:
thread_info = _ASYNC_CALLING_THREAD.get()
return thread_info[1]
except LookupError:
return threading.current_thread().name


def _get_callsite_process(module: str, frame: FrameType) -> Any:
Expand Down
19 changes: 18 additions & 1 deletion src/structlog/stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
import logging
import sys
import threading
import warnings

from collections.abc import Callable, Collection, Iterable, Sequence
Expand All @@ -33,7 +34,11 @@
from ._base import BoundLoggerBase
from ._frames import _find_first_app_frame_and_name, _format_stack
from ._log_levels import LEVEL_TO_NAME, NAME_TO_LEVEL, add_log_level
from .contextvars import _ASYNC_CALLING_STACK, merge_contextvars
from .contextvars import (
_ASYNC_CALLING_STACK,
_ASYNC_CALLING_THREAD,
merge_contextvars,
)
from .exceptions import DropEvent
from .processors import StackInfoRenderer
from .typing import (
Expand Down Expand Up @@ -424,6 +429,11 @@ async def _dispatch_to_sync(
"""
Merge contextvars and log using the sync logger in a thread pool.
"""
# Capture thread info before passing to executor
thread_id = threading.get_ident()
thread_name = threading.current_thread().name
thread_token = _ASYNC_CALLING_THREAD.set((thread_id, thread_name))

scs_token = _ASYNC_CALLING_STACK.set(sys._getframe().f_back.f_back) # type: ignore[union-attr, arg-type, unused-ignore]
ctx = contextvars.copy_context()

Expand All @@ -434,6 +444,7 @@ async def _dispatch_to_sync(
)
finally:
_ASYNC_CALLING_STACK.reset(scs_token)
_ASYNC_CALLING_THREAD.reset(thread_token)

async def adebug(self, event: str, *args: Any, **kw: Any) -> None:
"""
Expand Down Expand Up @@ -632,6 +643,11 @@ async def _dispatch_to_sync(
"""
Merge contextvars and log using the sync logger in a thread pool.
"""
# Capture thread info before passing to executor
thread_id = threading.get_ident()
thread_name = threading.current_thread().name
thread_token = _ASYNC_CALLING_THREAD.set((thread_id, thread_name))

scs_token = _ASYNC_CALLING_STACK.set(sys._getframe().f_back.f_back) # type: ignore[union-attr, arg-type, unused-ignore]
ctx = contextvars.copy_context()

Expand All @@ -642,6 +658,7 @@ async def _dispatch_to_sync(
)
finally:
_ASYNC_CALLING_STACK.reset(scs_token)
_ASYNC_CALLING_THREAD.reset(thread_token)

async def debug(self, event: str, *args: Any, **kw: Any) -> None:
await self._dispatch_to_sync(self.sync_bl.debug, event, args, kw)
Expand Down
48 changes: 47 additions & 1 deletion tests/processors/test_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def test_qual_name_logging_origin_absent(self) -> None:
async def test_async(self, wrapper_class, method_name) -> None:
"""
Callsite information for async invocations are correct.
Thread information is now correctly captured before async bridge.
"""
string_io = StringIO()

Expand All @@ -395,17 +396,62 @@ def __init__(self):

logger = structlog.stdlib.get_logger()

# Capture thread info before async call
expected_thread = threading.get_ident()
expected_thread_name = threading.current_thread().name

callsite_params = self.get_callsite_parameters()
await getattr(logger, method_name)("baz")
logger_params = json.loads(string_io.getvalue())

# These are different when running under async
# Thread info should now be correct (captured before async bridge)
assert logger_params["thread"] == expected_thread
assert logger_params["thread_name"] == expected_thread_name

# Remove thread info from comparison for remaining params
for key in ["thread", "thread_name"]:
callsite_params.pop(key)
logger_params.pop(key)

assert {"event": "baz", **callsite_params} == logger_params

@pytest.mark.asyncio
async def test_async_native_logger(self) -> None:
"""
Callsite thread information for native async invocations is correct.
"""
string_io = StringIO()

class StringIOLogger(structlog.PrintLogger):
def __init__(self):
super().__init__(file=string_io)

processor = CallsiteParameterAdder(
parameters=[
CallsiteParameter.THREAD,
CallsiteParameter.THREAD_NAME,
]
)
structlog.configure(
processors=[processor, JSONRenderer()],
logger_factory=StringIOLogger,
wrapper_class=structlog._native.BoundLoggerFilteringAtInfo,
cache_logger_on_first_use=True,
)

logger = structlog.get_logger()

# Capture thread info before async call
expected_thread = threading.get_ident()
expected_thread_name = threading.current_thread().name

await logger.ainfo("test native async")
logger_params = json.loads(string_io.getvalue())

# Thread info should now be correct (captured before async bridge)
assert logger_params["thread"] == expected_thread
assert logger_params["thread_name"] == expected_thread_name

def test_additional_ignores(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""
Stack frames from modules with names that start with values in
Expand Down