diff --git a/ddtrace/_trace/processor/resource_renaming.py b/ddtrace/_trace/processor/resource_renaming.py index 6ee6c664f95..21fd635f4f5 100644 --- a/ddtrace/_trace/processor/resource_renaming.py +++ b/ddtrace/_trace/processor/resource_renaming.py @@ -4,6 +4,7 @@ from urllib.parse import urlparse from ddtrace._trace.processor import SpanProcessor +from ddtrace._trace.span import Span from ddtrace.ext import SpanTypes from ddtrace.ext import http from ddtrace.internal.logger import get_logger @@ -13,7 +14,7 @@ log = get_logger(__name__) -class ResourceRenamingProcessor(SpanProcessor): +class SimplifiedEndpointComputer: def __init__(self): self._INT_RE = re.compile(r"^[1-9][0-9]+$") self._INT_ID_RE = re.compile(r"^(?=.*[0-9].*)[0-9._-]{3,}$") @@ -35,7 +36,7 @@ def _compute_simplified_endpoint_path_element(self, elem: str) -> str: return "{param:str}" return elem - def _compute_simplified_endpoint(self, url: Optional[str]) -> str: + def from_url(self, url: Optional[str]) -> str: """Extracts and simplifies the path from an HTTP URL.""" if not url: return "/" @@ -62,10 +63,15 @@ def _compute_simplified_endpoint(self, url: Optional[str]) -> str: elements = [self._compute_simplified_endpoint_path_element(elem) for elem in elements] return "/" + "/".join(elements) - def on_span_start(self, span): + +class ResourceRenamingProcessor(SpanProcessor): + def __init__(self): + self.simplified_endpoint_computer = SimplifiedEndpointComputer() + + def on_span_start(self, span: Span): pass - def on_span_finish(self, span): + def on_span_finish(self, span: Span): if not span._is_top_level or span.span_type not in (SpanTypes.WEB, SpanTypes.HTTP, SpanTypes.SERVERLESS): return @@ -73,5 +79,5 @@ def on_span_finish(self, span): if not route or config._trace_resource_renaming_always_simplified_endpoint: url = span.get_tag(http.URL) - endpoint = self._compute_simplified_endpoint(url) + endpoint = self.simplified_endpoint_computer.from_url(url) span.set_tag_str(http.ENDPOINT, endpoint) diff --git a/ddtrace/appsec/_api_security/api_manager.py b/ddtrace/appsec/_api_security/api_manager.py index d05f07a1002..f21f92ec727 100644 --- a/ddtrace/appsec/_api_security/api_manager.py +++ b/ddtrace/appsec/_api_security/api_manager.py @@ -3,13 +3,16 @@ import gzip import json import time -from typing import Optional +from typing import Optional, Union from ddtrace._trace._limits import MAX_SPAN_META_VALUE_LEN +from ddtrace._trace.processor.resource_renaming import SimplifiedEndpointComputer +from ddtrace.appsec._asm_request_context import ASM_Environment from ddtrace.appsec._constants import API_SECURITY from ddtrace.appsec._constants import SPAN_DATA_NAMES from ddtrace.appsec._trace_utils import _asm_manual_keep import ddtrace.constants as constants +from ddtrace.ext import http from ddtrace.internal import logger as ddlogger from ddtrace.internal.service import Service from ddtrace.settings.asm import config as asm_config @@ -77,6 +80,7 @@ def __init__(self) -> None: log.debug("%s initialized", self.__class__.__name__) self._hashtable: collections.OrderedDict[int, float] = collections.OrderedDict() + self.simplified_endpoint_computer = SimplifiedEndpointComputer() import ddtrace.appsec._asm_request_context as _asm_request_context import ddtrace.appsec._metrics as _metrics @@ -91,7 +95,7 @@ def _stop_service(self) -> None: def _start_service(self) -> None: self._asm_context.add_context_callback(self._schema_callback, global_callback=True) - def _should_collect_schema(self, env, priority: int) -> Optional[bool]: + def _should_collect_schema(self, env: ASM_Environment, priority: int) -> Optional[bool]: """ Rate limit per route. @@ -104,8 +108,21 @@ def _should_collect_schema(self, env, priority: int) -> Optional[bool]: return False method = env.waf_addresses.get(SPAN_DATA_NAMES.REQUEST_METHOD) + status: Union[str, int] = env.waf_addresses.get(SPAN_DATA_NAMES.RESPONSE_STATUS) # type: ignore[assignment] + + try: + int_status = int(status) + except ValueError: + int_status = None + route = env.waf_addresses.get(SPAN_DATA_NAMES.REQUEST_ROUTE) - status = env.waf_addresses.get(SPAN_DATA_NAMES.RESPONSE_STATUS) + if route is None and int_status != 404: + endpoint = env.entry_span.get_tag(http.ENDPOINT) + if endpoint is None: + url = env.entry_span.get_tag(http.URL) + endpoint = self.simplified_endpoint_computer.from_url(url) + route = endpoint + # Framework is not fully supported if method is None or route is None or status is None: log.debug( diff --git a/tests/appsec/appsec/api_security/test_api_security_manager.py b/tests/appsec/appsec/api_security/test_api_security_manager.py index f18560f10e6..74ae12e5666 100644 --- a/tests/appsec/appsec/api_security/test_api_security_manager.py +++ b/tests/appsec/appsec/api_security/test_api_security_manager.py @@ -10,6 +10,7 @@ from ddtrace.constants import AUTO_REJECT from ddtrace.constants import USER_KEEP from ddtrace.constants import USER_REJECT +from ddtrace.ext import http from tests.utils import override_global_config @@ -251,3 +252,88 @@ def test_schema_callback_parse_response_body_disabled(self, api_manager, mock_en assert len(mock_environment.entry_span._meta) == 0 api_manager._metrics._report_api_security.assert_called_with(True, 0) + + def test_should_collect_schema_route_fallbacks_to_endpoint(self, mock_environment): + """Test that _should_collect_schema falls back to endpoint tags when route is missing.""" + with override_global_config( + values=dict( + _asm_enabled=True, + _api_security_enabled=True, + _apm_tracing_enabled=True, + _api_security_parse_response_body=True, + ) + ): + manager = APIManager() + manager._appsec_processor = MagicMock() + manager._asm_context = MagicMock() + manager._metrics = MagicMock() + + mock_environment.entry_span.get_tag = lambda name: "/span-endpoint" if name == http.ENDPOINT else None + mock_environment.waf_addresses = { + SPAN_DATA_NAMES.REQUEST_ROUTE: None, + SPAN_DATA_NAMES.REQUEST_METHOD: "GET", + SPAN_DATA_NAMES.RESPONSE_STATUS: 200, + } + + # First request should collect + assert manager._should_collect_schema(mock_environment, USER_KEEP) + # Sencond one should discarded + assert not manager._should_collect_schema(mock_environment, USER_KEEP) + + def test_should_collect_schema_route_missing_computes_endpoint(self, mock_environment): + """Test that _should_collect_schema computes the endpoint value when route and endpoint tags are missing.""" + with override_global_config( + values=dict( + _asm_enabled=True, + _api_security_enabled=True, + _apm_tracing_enabled=True, + _api_security_parse_response_body=True, + ) + ): + manager = APIManager() + manager._appsec_processor = MagicMock() + manager._asm_context = MagicMock() + manager._metrics = MagicMock() + + def get_tag(name): + return "https://ddtrace.dog/span-endpoint" if name == http.URL else None + + mock_environment.entry_span.get_tag = get_tag + mock_environment.waf_addresses = { + SPAN_DATA_NAMES.REQUEST_ROUTE: None, + SPAN_DATA_NAMES.REQUEST_METHOD: "GET", + SPAN_DATA_NAMES.RESPONSE_STATUS: 200, + } + + # First request should collect + assert manager._should_collect_schema(mock_environment, USER_KEEP) + # Sencond one should discarded + assert not manager._should_collect_schema(mock_environment, USER_KEEP) + + @pytest.mark.parametrize("status_code", [404, "404", "invalid"]) + def test_should_not_collect_schema_on_404(self, mock_environment, status_code): + """Test that _should_collect_schema computes the endpoint value when route and endpoint tags are missing.""" + with override_global_config( + values=dict( + _asm_enabled=True, + _api_security_enabled=True, + _apm_tracing_enabled=True, + _api_security_parse_response_body=True, + ) + ): + manager = APIManager() + manager._appsec_processor = MagicMock() + manager._asm_context = MagicMock() + manager._metrics = MagicMock() + + def get_tag(name): + return "https://ddtrace.dog/span-endpoint" if name == http.URL else None + + mock_environment.entry_span.get_tag = get_tag + mock_environment.waf_addresses = { + SPAN_DATA_NAMES.REQUEST_ROUTE: None, + SPAN_DATA_NAMES.REQUEST_METHOD: "GET", + SPAN_DATA_NAMES.RESPONSE_STATUS: status_code, + } + + assert not manager._should_collect_schema(mock_environment, USER_KEEP) diff --git a/tests/snapshots/tests.contrib.django.test_django_appsec_snapshots.test_appsec_enabled_attack.json b/tests/snapshots/tests.contrib.django.test_django_appsec_snapshots.test_appsec_enabled_attack.json index cc9425e47f2..4bb06885f9f 100644 --- a/tests/snapshots/tests.contrib.django.test_django_appsec_snapshots.test_appsec_enabled_attack.json +++ b/tests/snapshots/tests.contrib.django.test_django_appsec_snapshots.test_appsec_enabled_attack.json @@ -14,6 +14,10 @@ "_dd.appsec.fp.http.network": "net-0-0000000000", "_dd.appsec.fp.session": "ssn----", "_dd.appsec.rc_products": "[] u:0 r:2", + "_dd.appsec.s.req.cookies": "H4sIALPv+GgC/4uuro0FAPz6p+oEAAAA", + "_dd.appsec.s.req.headers": "H4sIALPv+GgC/4uuVkrOz8tLTS7JzM9Tsoq2iNVRSkxOTi0oQeHopuYl56dk5qVDRUuLU4t0E9NT82DKMvKLIczaWABTHLJjUwAAAA==", + "_dd.appsec.s.req.query": "H4sIALPv+GgC/4uuro0FAPz6p+oEAAAA", + "_dd.appsec.s.res.headers": "H4sIALPv+GgC/12MQQqAMAwE/9KzuYtfKT0UiTWgSUmDKOLfFbEgvS3DzPrT7TBpXBEkGwkXN/g+dG4UNmSDBTnZXKFKKSBKifjxkVEhy0Lj8Qlb1Dp3qBd25PZdcULVtv4HL7rCDTehMqKgAAAA", "_dd.appsec.waf.version": "1.29.0", "_dd.base_service": "", "_dd.origin": "appsec", diff --git a/tests/tracer/test_resource_renaming.py b/tests/tracer/test_resource_renaming.py index 840f5b76541..bcc908bd06d 100644 --- a/tests/tracer/test_resource_renaming.py +++ b/tests/tracer/test_resource_renaming.py @@ -1,6 +1,7 @@ import pytest from ddtrace._trace.processor.resource_renaming import ResourceRenamingProcessor +from ddtrace._trace.processor.resource_renaming import SimplifiedEndpointComputer from ddtrace.ext import SpanTypes from ddtrace.ext import http from ddtrace.trace import Context @@ -44,8 +45,7 @@ class TestResourceRenaming: ], ) def test_compute_simplified_endpoint_path_element(self, elem, expected): - processor = ResourceRenamingProcessor() - result = processor._compute_simplified_endpoint_path_element(elem) + result = SimplifiedEndpointComputer()._compute_simplified_endpoint_path_element(elem) assert result == expected @pytest.mark.parametrize( @@ -89,8 +89,7 @@ def test_compute_simplified_endpoint_path_element(self, elem, expected): ], ) def test_compute_simplified_endpoint(self, url, expected): - processor = ResourceRenamingProcessor() - result = processor._compute_simplified_endpoint(url) + result = SimplifiedEndpointComputer().from_url(url) assert result == expected def test_processor_with_route(self):