9393import zlib
9494from abc import ABC
9595from base64 import b64encode
96+ from collections .abc import Callable , Mapping
9697from functools import cached_property
97- from typing import Callable , List , Mapping , Tuple , Type , Union
9898
9999import falcon .asgi
100100import msgspec
@@ -172,24 +172,42 @@ async def func_wrapper(*args, **kwargs):
172172 return func_wrapper
173173
174174
# Metadata entries collected when no explicit header selection is configured.
# "url" and "remote_addr" are read from request attributes by add_metadata;
# every other entry is looked up as an HTTP header of that name.
DEFAULT_META_HEADERS = frozenset(("url", "remote_addr", "user-agent"))
182+
183+
175184def add_metadata (func : Callable ):
176185 """Decorator to add metadata to resulting http event.
177186 Uses attribute collect_meta of endpoint class to decide over metadata collection
178187 Uses attribute metafield_name to define key name for metadata
179188 """
180189
181190 async def func_wrapper (* args , ** kwargs ):
182- req = args [1 ]
183- endpoint = args [0 ]
184- if endpoint .collect_meta :
185- metadata = {
186- "url" : req .url ,
187- "remote_addr" : req .remote_addr ,
188- "user_agent" : req .user_agent ,
189- }
190- kwargs ["metadata" ] = {endpoint .metafield_name : metadata }
191- else :
191+ req : falcon .Request = args [1 ]
192+ endpoint : HttpEndpoint = args [0 ]
193+
194+ if not endpoint .collect_meta or len (endpoint .copy_headers_to_logs ) == 0 :
192195 kwargs ["metadata" ] = {}
196+ else :
197+ metadata = {}
198+ for header in endpoint .copy_headers_to_logs :
199+ # remote_addr and url are special cases, because those are not copied 1 to 1 from headers
200+ match header :
201+ case "remote_addr" :
202+ metadata [header ] = req .remote_addr
203+ case "url" :
204+ metadata [header ] = req .url
205+ case _:
206+ key = header .replace ("-" , "_" ).lower ()
207+ metadata [key ] = req .get_header (header , required = False , default = None )
208+
209+ kwargs ["metadata" ] = {endpoint .metafield_name : metadata }
210+
193211 func_wrapper = await func (* args , ** kwargs )
194212 return func_wrapper
195213
@@ -231,9 +249,13 @@ def __init__(
231249 metafield_name : str ,
232250 credentials : Credentials ,
233251 metrics : "HttpInput.Metrics" ,
252+ copy_headers_to_logs : set [str ],
234253 ) -> None :
235254 self .messages = messages
236255 self .original_event_field = original_event_field
256+ self .copy_headers_to_logs = copy_headers_to_logs
257+
258+ # Deprecated
237259 self .collect_meta = collect_meta
238260 self .metafield_name = metafield_name
239261 self .credentials = credentials
@@ -271,6 +293,12 @@ async def get_data(self, req: falcon.Request) -> bytes:
271293 data = zlib .decompress (data , 31 )
272294 return data
273295
296+ def put_message (self , event : dict , metadata : dict ):
297+ """Puts message to internal queue"""
298+ if self .metafield_name in event :
299+ logger .warning ("metadata field was in event and got overwritten" )
300+ self .messages .put (event | metadata , block = False )
301+
274302
275303class JSONHttpEndpoint (HttpEndpoint ):
276304 """:code:`json` endpoint to get json from request"""
@@ -293,7 +321,7 @@ async def __call__(self, req, resp, **kwargs): # pylint: disable=arguments-diff
293321 )
294322 event = {}
295323 add_fields_to (event , {target_field : event_value })
296- self .messages . put (event | kwargs ["metadata" ], block = False )
324+ self .put_message (event , kwargs ["metadata" ])
297325
298326
299327class JSONLHttpEndpoint (HttpEndpoint ):
@@ -317,7 +345,8 @@ async def __call__(self, req, resp, **kwargs): # pylint: disable=arguments-diff
317345 )
318346 event = {}
319347 add_fields_to (event , {target_field : event_value })
320- self .messages .put (event | kwargs ["metadata" ], block = False , batch_size = len (events ))
348+
349+ self .put_message (event , kwargs ["metadata" ])
321350
322351
323352class PlaintextHttpEndpoint (HttpEndpoint ):
@@ -339,7 +368,7 @@ async def __call__(self, req, resp, **kwargs): # pylint: disable=arguments-diff
339368 )
340369 event = {}
341370 add_fields_to (event , {target_field : event_value })
342- self .messages . put (event | kwargs ["metadata" ], block = False )
371+ self .put_message (event , kwargs ["metadata" ])
343372
344373
345374class HttpInput (Input ):
@@ -369,7 +398,7 @@ class Metrics(Input.Metrics):
369398 class Config (Input .Config ):
370399 """Config for HTTPInput"""
371400
372- uvicorn_config : Mapping [str , Union [ str , int ] ] = field (
401+ uvicorn_config : Mapping [str , str | int ] = field (
373402 validator = [
374403 validators .instance_of (dict ),
375404 validators .deep_mapping (
@@ -432,8 +461,32 @@ class Config(Input.Config):
432461 be smaller than default value of 15.000 messages.
433462 """
434463
435- collect_meta : str = field (validator = validators .instance_of (bool ), default = True )
436- """Defines if metadata should be collected
464+ copy_headers_to_logs : set [str ] = field (
465+ validator = validators .deep_iterable (
466+ member_validator = validators .instance_of (str ),
467+ iterable_validator = validators .or_ (
468+ validators .instance_of (set ), validators .instance_of (list )
469+ ),
470+ ),
471+ converter = set ,
472+ factory = lambda : set (DEFAULT_META_HEADERS ),
473+ )
474+ """Defines what metadata should be collected from Http Headers
475+ Special cases:
476+ - remote_addr (Gets the inbound client ip instead of header)
477+ - url (Get the requested url from http request and not technically a header)
478+
479+ Defaults:
480+ - remote_addr
481+ - url
482+ - User-Agent
483+
484+ The output header names in Events are stored as json strings, and are transformed from "User-Agent" to "user_agent"
485+ """
486+
487+ collect_meta : bool = field (validator = validators .instance_of (bool ), default = True )
488+ """Deprecated use copy_headers_to_logs instead, to turn off collecting metadata set copy_headers_to_logs to an empty list ([]).
489+ Defines if metadata should be collected
437490 - :code:`True`: Collect metadata
438491 - :code:`False`: Won't collect metadata
439492
@@ -445,11 +498,15 @@ class Config(Input.Config):
445498 """
446499
447500 metafield_name : str = field (validator = validators .instance_of (str ), default = "@metadata" )
448- """Defines the name of the key for the collected metadata fields"""
501+ """Defines the name of the key for the collected metadata fields
502+ Logs a Warning if metadata field overwrites preexisting field in Event
503+ """
449504
450505 original_event_field : dict = field (
451506 validator = [
507+ # type: ignore
452508 validators .optional (
509+ # type: ignore
453510 validators .deep_mapping (
454511 key_validator = validators .in_ (["format" , "target_field" ]),
455512 value_validator = validators .instance_of (str ),
@@ -469,11 +526,11 @@ def __attrs_post_init__(self):
469526 "Cannot configure both add_full_event_to_target_field and original_event_field."
470527 )
471528
472- __slots__ : List [str ] = ["target" , "app" , "http_server" ]
529+ __slots__ : list [str ] = ["target" , "app" , "http_server" ]
473530
474531 messages : mp .Queue = None
475532
476- _endpoint_registry : Mapping [str , Type [HttpEndpoint ]] = {
533+ _endpoint_registry : Mapping [str , type [HttpEndpoint ]] = {
477534 "json" : JSONHttpEndpoint ,
478535 "plaintext" : PlaintextHttpEndpoint ,
479536 "jsonl" : JSONLHttpEndpoint ,
@@ -506,6 +563,7 @@ def setup(self):
506563
507564 endpoints_config = {}
508565 collect_meta = self ._config .collect_meta
566+ copy_headers_to_logs = self ._config .copy_headers_to_logs
509567 metafield_name = self ._config .metafield_name
510568 original_event_field = self ._config .original_event_field
511569 cred_factory = CredentialsFactory ()
@@ -521,6 +579,7 @@ def setup(self):
521579 metafield_name ,
522580 credentials ,
523581 self .metrics ,
582+ copy_headers_to_logs ,
524583 )
525584
526585 self .app = self ._get_asgi_app (endpoints_config )
@@ -537,7 +596,7 @@ def _get_asgi_app(endpoints_config: dict) -> falcon.asgi.App:
537596 app .add_sink (endpoint , prefix = route_compile_helper (endpoint_path ))
538597 return app
539598
540- def _get_event (self , timeout : float ) -> Tuple :
599+ def _get_event (self , timeout : float ) -> tuple :
541600 """Returns the first message from the queue"""
542601 self .metrics .message_backlog_size += self .messages .qsize ()
543602 try :
@@ -554,7 +613,7 @@ def shut_down(self):
554613 self .http_server .shut_down ()
555614
556615 @cached_property
557- def health_endpoints (self ) -> List [str ]:
616+ def health_endpoints (self ) -> list [str ]:
558617 """Returns a list of endpoints for internal healthcheck
559618 the endpoints are examples to match against the configured regex enabled
560619 endpoints. The endpoints are normalized to match the regex patterns and
0 commit comments