Skip to content

Implement tracing decorator #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ COPY --chown=nonroot:nonroot poetry.lock .
COPY --chown=nonroot:nonroot src/alembic ./alembic
COPY --chown=nonroot:nonroot src/domains ./domains
COPY --chown=nonroot:nonroot src/gateways ./gateways
COPY --chown=nonroot:nonroot src/bootstrap ./bootstrap
COPY --chown=nonroot:nonroot src/common ./bootstrap
COPY --chown=nonroot:nonroot src/alembic.ini .
COPY --chown=nonroot:nonroot Makefile .

Expand Down
4 changes: 2 additions & 2 deletions docs/inversion-of-control.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def book_repository_factory() -> BookRepositoryInterface:

# file `domains/books/_service.py`
from domains.books._gateway_interfaces import BookRepositoryInterface
from bootstrap.factories import book_repository_factory
from common.factories import book_repository_factory


class BookService:
Expand Down Expand Up @@ -274,7 +274,7 @@ def inject_book_repository(f):
def wrapper(*args, **kwds):
# This allows overriding the decorator
if "book_repository" not in kwds.keys():
from bootstrap.storage import BookRepository
from common.storage import BookRepository
kwds["book_repository"] = BookRepository()
elif not isinstance(kwds["book_repository"], BookRepositoryInterface):
import warnings
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ branch = true
source = ["src"]
omit = [
"src/alembic/*",
"src/bootstrap/config.py",
"src/bootstrap/logs/*",
"src/common/config.py",
"src/common/logs/*",
]
# It's not necessary to configure concurrency here
# because pytest-cov takes care of that
Expand Down
4 changes: 2 additions & 2 deletions src/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlalchemy.ext.asyncio import AsyncEngine

from alembic import context
from bootstrap.bootstrap import application_init
from bootstrap.config import AppConfig
from common.bootstrap import application_init
from common.config import AppConfig

USE_TWOPHASE = False

Expand Down
2 changes: 1 addition & 1 deletion src/celery_worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from celery.signals import worker_process_init
from opentelemetry.instrumentation.celery import CeleryInstrumentor

from bootstrap import AppConfig, application_init
from common import AppConfig, application_init


@worker_process_init.connect(weak=False)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy_bind_manager import SQLAlchemyBindManager
from sqlalchemy_bind_manager.repository import SQLAlchemyAsyncRepository

from bootstrap.config import AppConfig
from common.config import AppConfig
from domains.books._gateway_interfaces import (
BookEventGatewayInterface,
BookRepositoryInterface,
Expand All @@ -21,7 +21,7 @@ class Container(DeclarativeContainer):

wiring_config = WiringConfiguration(
packages=[
"bootstrap",
"common",
"domains",
]
)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
56 changes: 56 additions & 0 deletions src/common/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import asyncio
from functools import wraps

from opentelemetry import trace

# Get the _tracer instance (You can set your own _tracer name)
tracer = trace.get_tracer(__name__)


def trace_function(trace_attributes: bool = True, trace_result: bool = True):
"""
Decorator to trace callables using OpenTelemetry spans.

Parameters:
- trace_attributes (bool): If False, disables adding function arguments to the span.
- trace_result (bool): If False, disables adding the function's result to the span.
"""

def decorator(func):
@wraps(func)
def sync_or_async_wrapper(*args, **kwargs):
with tracer.start_as_current_span(func.__name__) as span:
try:
# Set function arguments as attributes
if trace_attributes:
span.set_attribute("function.args", str(args))
span.set_attribute("function.kwargs", str(kwargs))

async def async_handler():
result = await func(*args, **kwargs)
# Add result to span
if trace_result:
span.set_attribute("function.result", str(result))
return result

def sync_handler():
result = func(*args, **kwargs)
# Add result to span
if trace_result:
span.set_attribute("function.result", str(result))
return result

if asyncio.iscoroutinefunction(func):
return async_handler()
else:
return sync_handler()

except Exception as e:
# Record the exception in the span
span.record_exception(e)
span.set_status(trace.status.Status(trace.status.StatusCode.ERROR))
raise

return sync_or_async_wrapper

return decorator
26 changes: 26 additions & 0 deletions src/common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
def apply_decorator_to_methods(
decorator, protected_methods: bool = False, private_methods: bool = False
):
"""
Class decorator to apply a given function or coroutine decorator
to all functions and coroutines within a class.
"""

def class_decorator(cls):
for attr_name, attr_value in cls.__dict__.items():
# Check if the attribute is a callable (method or coroutine)
if not callable(attr_value):
continue

if attr_name.startswith(f"_{cls.__name__}__"):
if not private_methods:
continue

elif attr_name.startswith("_") and not protected_methods:
continue

# Replace the original callable with the decorated version
setattr(cls, attr_name, decorator(attr_value))
return cls

return class_decorator
4 changes: 4 additions & 0 deletions src/domains/books/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from dependency_injector.wiring import Provide, inject
from structlog import get_logger

from common.tracing import trace_function
from common.utils import apply_decorator_to_methods

from ._gateway_interfaces import BookEventGatewayInterface, BookRepositoryInterface
from ._models import BookModel
from ._tasks import book_cpu_intensive_task
from .dto import Book, BookData
from .events import BookCreatedV1, BookCreatedV1Data


@apply_decorator_to_methods(trace_function())
class BookService:
_book_repository: BookRepositoryInterface
_event_gateway: BookEventGatewayInterface
Expand Down
2 changes: 1 addition & 1 deletion src/http_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from starlette_prometheus import PrometheusMiddleware, metrics
from structlog import get_logger

from bootstrap import AppConfig, application_init
from common import AppConfig, application_init
from http_app.routes import init_routes


Expand Down
Empty file added tests/common/__init__.py
Empty file.
156 changes: 156 additions & 0 deletions tests/common/test_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import asyncio
from unittest.mock import MagicMock, call, patch

import pytest

from common.tracing import trace_function


@pytest.fixture
def mock_tracer():
"""
Fixture to mock the OpenTelemetry tracer and span.
"""
mock_tracer = MagicMock()
mock_span = MagicMock()
mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span

with (
patch("opentelemetry.trace.get_tracer", return_value=mock_tracer),
patch("common.tracing.tracer", mock_tracer),
):
yield mock_tracer, mock_span


def test_sync_function_default_params(mock_tracer):
"""
Test a synchronous function with default decorator parameters.
"""
mock_tracer, mock_span = mock_tracer

# Define a sync function to wrap with the decorator
@trace_function()
def add_nums(a, b):
return a + b

# Call the function
result = add_nums(2, 3)

# Assertions
assert result == 5
mock_tracer.start_as_current_span.assert_called_once_with("add_nums")
mock_span.set_attribute.assert_any_call("function.args", "(2, 3)")
mock_span.set_attribute.assert_any_call("function.result", "5")


@pytest.mark.asyncio
async def test_async_function_default_params(mock_tracer):
"""
Test an asynchronous function with default decorator parameters.
"""
mock_tracer, mock_span = mock_tracer

# Define an async function to wrap with the decorator
@trace_function()
async def async_func(a, b):
await asyncio.sleep(0.1)
return a * b

# Run the async function
result = await async_func(4, 5)

# Assertions
assert result == 20
mock_tracer.start_as_current_span.assert_called_once_with("async_func")
mock_span.set_attribute.assert_any_call("function.args", "(4, 5)")
mock_span.set_attribute.assert_any_call("function.result", "20")


def test_disable_function_attributes(mock_tracer):
"""
Test a synchronous function with `add_function_attributes` set to False.
"""
mock_tracer, mock_span = mock_tracer

# Define a sync function with attributes disabled
@trace_function(trace_attributes=False)
def sync_func(a, b):
return a - b

# Call the function
result = sync_func(10, 6)

# Assertions
assert result == 4
mock_tracer.start_as_current_span.assert_called_once_with("sync_func")
mock_span.set_attribute.assert_any_call("function.result", "4")
assert (
call("function.args", "(10, 6)") not in mock_span.set_attribute.call_args_list
)


def test_disable_result_in_span_sync(mock_tracer):
"""
Test an asynchronous function with `add_result_to_span` set to False.
"""
mock_tracer, mock_span = mock_tracer

# Define an async function with result disabled
@trace_function(trace_result=False)
def sync_func(a, b):
return a / b

# Run the async function
result = sync_func(10, 2)

# Assertions
assert result == 5.0
mock_tracer.start_as_current_span.assert_called_once_with("sync_func")
mock_span.set_attribute.assert_any_call("function.args", "(10, 2)")
assert call("function.result") not in mock_span.set_attribute.call_args_list


@pytest.mark.asyncio
async def test_disable_result_in_span(mock_tracer):
"""
Test an asynchronous function with `add_result_to_span` set to False.
"""
mock_tracer, mock_span = mock_tracer

# Define an async function with result disabled
@trace_function(trace_result=False)
async def async_func(a, b):
await asyncio.sleep(0.1)
return a / b

# Run the async function
result = await async_func(10, 2)

# Assertions
assert result == 5.0
mock_tracer.start_as_current_span.assert_called_once_with("async_func")
mock_span.set_attribute.assert_any_call("function.args", "(10, 2)")
assert call("function.result") not in mock_span.set_attribute.call_args_list


def test_exception_in_function(mock_tracer):
"""
Test behavior when the function raises an exception.
"""
mock_tracer, mock_span = mock_tracer

# Define a failing function
@trace_function()
def failing_func(a, b):
if b == 0:
raise ValueError("Division by zero!")
return a / b

# Use pytest to assert the exception is raised
with pytest.raises(ValueError, match="Division by zero!"):
failing_func(10, 0)

# Assertions
mock_tracer.start_as_current_span.assert_called_once_with("failing_func")
mock_span.record_exception.assert_called_once()
mock_span.set_status.assert_called_once()
Loading
Loading