1- import asyncio
2- import json
31import logging
4- from dataclasses import dataclass
2+ import os
3+ import sys
4+ from dataclasses import dataclass , asdict
55from typing import Final , Optional
66
77from middleware .request_context import RequestContext , get_request_context
8+ from shared .common .logging_utils import get_env_logging_level
89from utils .config import get_config , PROJECT_ID
910from utils .dict_utils import get_safe_value
1011
12+ import google .cloud .logging
13+ from google .cloud .logging_v2 .handlers import CloudLoggingFilter , CloudLoggingHandler
14+
1115API_ACCESS_LOG : Final [str ] = "api-access-log"
1216CLOUD_RUN_SERVICE_ID : Final [str ] = "K_SERVICE"
1317CLOUD_RUN_REVISION_ID : Final [str ] = "K_REVISION"
@@ -27,101 +31,125 @@ class HttpRequest:
2731 userAgent : str
2832 remoteIp : str
2933 serverIp : str
30- latency : float
34+ latency : str
3135 protocol : str
3236
3337
34- @dataclass
35- class LogRecord :
38+ def get_trace (request_context : RequestContext ):
3639 """
37- Data class for Log Record
40+ Get the trace id from the log record
3841 """
42+ trace = ""
43+ trace_id = get_safe_value (request_context , "trace_id" )
44+ if trace_id :
45+ trace = f"projects/{ get_config (PROJECT_ID , '' )} /traces/{ trace_id } "
46+ return trace
3947
40- user_id : str
41- httpRequest : dict
42- trace : str
43- spanId : str
44- traceSampled : bool
45- textPayload : Optional [str ]
46- jsonPayload : Optional [dict ]
48+
49+ def get_http_request (record ) -> HttpRequest | None :
50+ """
51+ Get the http request from the log record
52+ If the http request is not found, return None
53+ """
54+ context = record .__getattribute__ ("context" ) if hasattr (record , "context" ) else None
55+ return context .get ("http_request" ) if context else {}
4756
4857
49- class AsyncStreamHandler ( logging . StreamHandler ):
58+ class GoogleCloudLogFilter ( CloudLoggingFilter ):
5059 """
51- Async Stream Handler
60+ Log filter for Google Cloud Logging.
61+ This filter adds the trace, span and http_request fields to the log record.
5262 """
5363
54- def __init__ (self , * args , ** kwargs ):
55- super ().__init__ (* args , ** kwargs )
56- self .loop = asyncio .get_event_loop ()
64+ def filter (self , record : logging .LogRecord ) -> bool :
65+ request_context = get_request_context ()
66+ http_request = get_http_request (record )
67+ if http_request :
68+ record .http_request = asdict (http_request )
69+ span_id = request_context .get ("span_id" )
70+ trace = get_trace (request_context )
71+ record .trace = trace
72+ record .span_id = span_id
73+
74+ record ._log_fields = {
75+ "logging.googleapis.com/trace" : trace ,
76+ "logging.googleapis.com/spanId" : span_id ,
77+ "logging.googleapis.com/httpRequest" : asdict (http_request ) if http_request else None ,
78+ "logging.googleapis.com/trace_sampled" : True ,
79+ }
80+ super ().filter (record )
81+
82+ return True
83+
84+
85+ class StderrToLog :
86+ """
87+ Redirect stderr to log
88+ """
5789
58- def emit (self , record ):
59- """
60- Emit the log record
61- """
62- asyncio .ensure_future (self .async_emit (record ))
90+ def __init__ (self , logger ):
91+ self .logger = logger
6392
64- async def async_emit (self , record ):
65- """
66- Async emit the log record
67- """
68- msg = self .format (record )
69- stream = self .stream
70- await self .loop .run_in_executor (None , stream .write , msg )
71- await self .loop .run_in_executor (None , stream .flush )
93+ def write (self , message ):
94+ message = message .strip ()
95+ if message :
96+ self .logger .error (message )
97+
98+ def flush (self ):
99+ pass
72100
73101
74- class GCPLogHandler ( AsyncStreamHandler ):
102+ def get_logger ( name : Optional [ str ] ):
75103 """
76- GCP Log Handler
104+ Returns a logger with the name making sure the propagate flag is set to True.
77105 """
106+ logger = logging .getLogger (name )
107+ logger .propagate = True
108+ return logger
78109
79- def __init__ (self ):
80- console_handler = logging .StreamHandler ()
81- self .logger = logging .getLogger ()
82- self .logger .addHandler (console_handler )
83- self .logger .setLevel (logging .DEBUG )
84- super ().__init__ ()
85-
86- @staticmethod
87- def get_trace (request_context : RequestContext ):
88- """
89- Get the trace id from the log record
90- """
91- trace = ""
92- trace_id = get_safe_value (request_context , "trace_id" )
93- if trace_id :
94- trace = f"projects/{ get_config (PROJECT_ID , '' )} /traces/{ trace_id } "
95- return trace
96-
97- @staticmethod
98- def get_http_request (record ) -> HttpRequest :
99- context = record .__getattribute__ ("context" ) if hasattr (record , "context" ) else None
100- return context .get ("http_request" ) if context else {}
101-
102- async def async_emit (self , record ):
103- """
104- Emit the GCP log record
105- """
106- http_request = self .get_http_request (record )
107- request_context = get_request_context ()
108- text_payload = None
109- json_payload = None
110- message = record .msg if hasattr (record , "msg" ) else None
111- message = record .getMessage () if message is None and hasattr (record , "getMessage" ) else message
112- if message :
113- if type (message ) is dict :
114- json_payload = message
115- else :
116- text_payload = str (message )
117-
118- log_record : LogRecord = LogRecord (
119- httpRequest = http_request .__dict__ if not isinstance (http_request , dict ) else {},
120- trace = self .get_trace (request_context ),
121- spanId = request_context .get ("span_id" ),
122- traceSampled = request_context .get ("trace_sampled" ),
123- user_id = request_context .get ("user_id" ),
124- textPayload = text_payload ,
125- jsonPayload = json_payload ,
126- )
127- self .logger .info (json .dumps (log_record .__dict__ ))
110+
111+ def is_local_env ():
112+ """
113+ Returns: True if the environment is local, False otherwise
114+ """
115+ return os .getenv ("K_SERVICE" ) is None
116+
117+
118+ def global_logging_setup ():
119+ if is_local_env ():
120+ logging .basicConfig (level = get_env_logging_level ())
121+ return
122+
123+ # Send warnings through logging
124+ logging .captureWarnings (True )
125+ # Replace sys.stderr
126+ sys .stderr = StderrToLog (logging .getLogger ("stderr" ))
127+ try :
128+ client = google .cloud .logging .Client ()
129+ handler = CloudLoggingHandler (client , structured = True )
130+ handler .setLevel (get_env_logging_level ())
131+ handler .addFilter (GoogleCloudLogFilter (project = client .project ))
132+ except Exception as e :
133+ logging .error ("Error initializing cloud logging: %s" , e )
134+ logging .basicConfig (level = get_env_logging_level ())
135+ return
136+
137+ # Configure root logger
138+ root_logger = logging .getLogger ()
139+ root_logger .setLevel (get_env_logging_level ())
140+ root_logger .handlers .clear ()
141+ root_logger .addHandler (handler )
142+
143+ # This overrides individual logs essential for debugging purposes.
144+ for name in [
145+ "sqlalchemy" ,
146+ "uvicorn" ,
147+ "uvicorn.error" ,
148+ "uvicorn.access" ,
149+ "sqlalchemy.exc" ,
150+ "feed-api" ,
151+ "sqlalchemy.engine" ,
152+ ]:
153+ get_logger (name )
154+
155+ logging .info ("Setting cloud up logging completed" )
0 commit comments