diff --git a/logfire/_internal/cli/run.py b/logfire/_internal/cli/run.py index f62f61e2d..8a49ef638 100644 --- a/logfire/_internal/cli/run.py +++ b/logfire/_internal/cli/run.py @@ -171,7 +171,14 @@ def instrument_packages(installed_otel_packages: set[str], instrument_pkg_map: d def instrument_package(import_name: str): instrument_attr = f'instrument_{import_name}' - getattr(logfire, instrument_attr)() + + # On those packages, the public `logfire.instrument_` function needs to receive the app object. + # But the internal function doesn't need it. + if import_name in ('starlette', 'fastapi', 'flask'): + module = importlib.import_module(f'logfire._internal.integrations.{import_name}') + getattr(module, instrument_attr)(logfire.DEFAULT_LOGFIRE_INSTANCE) + else: + getattr(logfire, instrument_attr)() def find_recommended_instrumentations_to_install( diff --git a/logfire/_internal/integrations/fastapi.py b/logfire/_internal/integrations/fastapi.py index 165e1bce9..eb2997f47 100644 --- a/logfire/_internal/integrations/fastapi.py +++ b/logfire/_internal/integrations/fastapi.py @@ -50,7 +50,10 @@ def find_mounted_apps(app: FastAPI) -> list[FastAPI]: def instrument_fastapi( logfire_instance: Logfire, - app: FastAPI, + # Note that `Logfire.instrument_fastapi()` requires this argument. It's only omitted when called via the + # `logfire run` CLI. This is because `FastAPIInstrumentor.instrument() has to be called before + # `from fastapi import FastAPI` which is easy to get wrong. + app: FastAPI | None = None, *, capture_headers: bool = False, request_attributes_mapper: Callable[ @@ -77,44 +80,52 @@ def instrument_fastapi( maybe_capture_server_headers(capture_headers) opentelemetry_kwargs = { + 'excluded_urls': excluded_urls, + 'server_request_hook': _server_request_hook(opentelemetry_kwargs.pop('server_request_hook', None)), 'tracer_provider': tweak_asgi_spans_tracer_provider(logfire_instance, record_send_receive), 'meter_provider': logfire_instance.config.get_meter_provider(), **opentelemetry_kwargs, } - FastAPIInstrumentor.instrument_app( - app, - excluded_urls=excluded_urls, - server_request_hook=_server_request_hook(opentelemetry_kwargs.pop('server_request_hook', None)), - **opentelemetry_kwargs, - ) - registry = patch_fastapi() - if app in registry: # pragma: no cover - raise ValueError('This app has already been instrumented.') + if app is None: + FastAPIInstrumentor().instrument(**opentelemetry_kwargs) - mounted_apps = find_mounted_apps(app) - mounted_apps.append(app) + @contextmanager + def uninstrument_context(): # pragma: no cover + yield + FastAPIInstrumentor().uninstrument() - for _app in mounted_apps: - registry[_app] = FastAPIInstrumentation( - logfire_instance, - request_attributes_mapper or _default_request_attributes_mapper, - extra_spans=extra_spans, - ) + return uninstrument_context() + else: + FastAPIInstrumentor.instrument_app(app, **opentelemetry_kwargs) - @contextmanager - def uninstrument_context(): - # The user isn't required (or even expected) to use this context manager, - # which is why the instrumenting and patching has already happened before this point. - # It exists mostly for tests, and just in case users want it. - try: - yield - finally: - for _app in mounted_apps: - del registry[_app] - FastAPIInstrumentor.uninstrument_app(_app) + registry = patch_fastapi() + if app in registry: # pragma: no cover + raise ValueError('This app has already been instrumented.') + + mounted_apps = find_mounted_apps(app) + mounted_apps.append(app) + + for _app in mounted_apps: + registry[_app] = FastAPIInstrumentation( + logfire_instance, + request_attributes_mapper or _default_request_attributes_mapper, + extra_spans=extra_spans, + ) + + @contextmanager + def uninstrument_context(): + # The user isn't required (or even expected) to use this context manager, + # which is why the instrumenting and patching has already happened before this point. + # It exists mostly for tests, and just in case users want it. + try: + yield + finally: + for _app in mounted_apps: + del registry[_app] + FastAPIInstrumentor.uninstrument_app(_app) - return uninstrument_context() + return uninstrument_context() @lru_cache # only patch once diff --git a/logfire/_internal/integrations/flask.py b/logfire/_internal/integrations/flask.py index 55582b135..5a10a1320 100644 --- a/logfire/_internal/integrations/flask.py +++ b/logfire/_internal/integrations/flask.py @@ -15,16 +15,21 @@ " pip install 'logfire[flask]'" ) +from logfire import Logfire from logfire._internal.utils import maybe_capture_server_headers from logfire.integrations.flask import CommenterOptions, RequestHook, ResponseHook def instrument_flask( - app: Flask, + logfire_instance: Logfire, + # Note that `Logfire.instrument_flask()` requires this argument. It's only omitted when called via the + # `logfire run` CLI. This is because `FlaskInstrumentor.instrument_app()` has to be called before + # `from flask import Flask` which is easy to get wrong. + app: Flask | None = None, *, - capture_headers: bool, - enable_commenter: bool, - commenter_options: CommenterOptions | None, + capture_headers: bool = False, + enable_commenter: bool = False, + commenter_options: CommenterOptions | None = None, excluded_urls: str | None = None, request_hook: RequestHook | None = None, response_hook: ResponseHook | None = None, @@ -41,12 +46,18 @@ def instrument_flask( warn_at_user_stacklevel('exclude_urls is deprecated; use excluded_urls instead', DeprecationWarning) excluded_urls = excluded_urls or kwargs.pop('exclude_urls', None) - FlaskInstrumentor().instrument_app( # type: ignore[reportUnknownMemberType] - app, - enable_commenter=enable_commenter, - commenter_options=commenter_options, - excluded_urls=excluded_urls, - request_hook=request_hook, - response_hook=response_hook, + opentelemetry_kwargs = { + 'enable_commenter': enable_commenter, + 'commenter_options': commenter_options, + 'excluded_urls': excluded_urls, + 'request_hook': request_hook, + 'response_hook': response_hook, + 'tracer_provider': logfire_instance.config.get_tracer_provider(), + 'meter_provider': logfire_instance.config.get_meter_provider(), **kwargs, - ) + } + + if app is None: + FlaskInstrumentor().instrument(**opentelemetry_kwargs) + else: + FlaskInstrumentor().instrument_app(app, **opentelemetry_kwargs) # type: ignore[reportUnknownMemberType] diff --git a/logfire/_internal/integrations/starlette.py b/logfire/_internal/integrations/starlette.py index 5c7bcccfd..d20d74d76 100644 --- a/logfire/_internal/integrations/starlette.py +++ b/logfire/_internal/integrations/starlette.py @@ -21,7 +21,10 @@ def instrument_starlette( logfire_instance: Logfire, - app: Starlette, + # Note that `Logfire.instrument_starlette()` requires this argument. It's only omitted when called via the + # `logfire run` CLI. This is because `StarletteInstrumentor.instrument()` has to be called before + # `from starlette import Starlette` which is easy to get wrong. + app: Starlette | None = None, *, record_send_receive: bool = False, capture_headers: bool = False, @@ -35,14 +38,26 @@ def instrument_starlette( See the `Logfire.instrument_starlette` method for details. """ maybe_capture_server_headers(capture_headers) - StarletteInstrumentor().instrument_app( - app, - server_request_hook=server_request_hook, - client_request_hook=client_request_hook, - client_response_hook=client_response_hook, - **{ # type: ignore - 'tracer_provider': tweak_asgi_spans_tracer_provider(logfire_instance, record_send_receive), - 'meter_provider': logfire_instance.config.get_meter_provider(), - **kwargs, - }, - ) + if app is None: + StarletteInstrumentor().instrument( + server_request_hook=server_request_hook, + client_request_hook=client_request_hook, + client_response_hook=client_response_hook, + **{ + 'tracer_provider': tweak_asgi_spans_tracer_provider(logfire_instance, record_send_receive), + 'meter_provider': logfire_instance.config.get_meter_provider(), + **kwargs, + }, + ) + else: + StarletteInstrumentor().instrument_app( + app, + server_request_hook=server_request_hook, + client_request_hook=client_request_hook, + client_response_hook=client_response_hook, + **{ # type: ignore + 'tracer_provider': tweak_asgi_spans_tracer_provider(logfire_instance, record_send_receive), + 'meter_provider': logfire_instance.config.get_meter_provider(), + **kwargs, + }, + ) diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 102036960..44d38c4c4 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -11,15 +11,7 @@ from enum import Enum from functools import cached_property from time import time -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - TypeVar, - Union, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload import opentelemetry.context as context_api import opentelemetry.trace as trace_api @@ -1589,6 +1581,7 @@ def instrument_flask( self._warn_if_not_initialized_for_instrumentation() return instrument_flask( + self, app, capture_headers=capture_headers, enable_commenter=enable_commenter, @@ -1596,11 +1589,7 @@ def instrument_flask( excluded_urls=excluded_urls, request_hook=request_hook, response_hook=response_hook, - **{ - 'tracer_provider': self._config.get_tracer_provider(), - 'meter_provider': self._config.get_meter_provider(), - **kwargs, - }, + **kwargs, ) def instrument_starlette( diff --git a/tests/test_cli.py b/tests/test_cli.py index 70f17fa0f..96b09c6ed 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -31,6 +31,7 @@ ) from logfire._internal.config import LogfireCredentials, sanitize_project_name from logfire.exceptions import LogfireConfigError +from logfire.testing import TestExporter from tests.import_used_for_tests import run_script_test @@ -1582,6 +1583,38 @@ async def test_instrument_packages_aiohttp_client() -> None: AioHttpClientInstrumentor().uninstrument() +def test_instrument_web_frameworks(exporter: TestExporter) -> None: + try: + instrument_packages( + { + 'opentelemetry-instrumentation-starlette', + 'opentelemetry-instrumentation-fastapi', + 'opentelemetry-instrumentation-flask', + }, + { + 'opentelemetry-instrumentation-starlette': 'starlette', + 'opentelemetry-instrumentation-fastapi': 'fastapi', + 'opentelemetry-instrumentation-flask': 'flask', + }, + ) + + from fastapi import FastAPI + from flask import Flask + from starlette.applications import Starlette + + assert getattr(Starlette(), '_is_instrumented_by_opentelemetry', False) is True + assert getattr(FastAPI(), '_is_instrumented_by_opentelemetry', False) is True + assert getattr(Flask(__name__), '_is_instrumented_by_opentelemetry', False) is True + finally: + from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor + from opentelemetry.instrumentation.flask import FlaskInstrumentor + from opentelemetry.instrumentation.starlette import StarletteInstrumentor + + StarletteInstrumentor().uninstrument() + FastAPIInstrumentor().uninstrument() + FlaskInstrumentor().uninstrument() + + def test_split_args_action() -> None: parser = argparse.ArgumentParser() parser.add_argument('--foo', action=SplitArgs)