3131
3232from __future__ import absolute_import
3333
34+ import asyncio
35+ import functools
3436from typing import Dict , Optional
3537
3638import starlette
37- from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
3839from starlette .requests import Request
39- from starlette .responses import Response
4040from starlette .routing import Match , Mount
41- from starlette .types import ASGIApp
41+ from starlette .types import ASGIApp , Message
4242
4343import elasticapm
4444import elasticapm .instrumentation .control
@@ -70,32 +70,28 @@ def make_apm_client(config: Optional[Dict] = None, client_cls=Client, **defaults
7070 return client_cls (config , ** defaults )
7171
7272
73- class ElasticAPM ( BaseHTTPMiddleware ) :
73+ class ElasticAPM :
7474 """
7575 Starlette / FastAPI middleware for Elastic APM capturing.
7676
77- >>> elasticapm = make_apm_client({
77+ >>> apm = make_apm_client({
7878 >>> 'SERVICE_NAME': 'myapp',
7979 >>> 'DEBUG': True,
8080 >>> 'SERVER_URL': 'http://localhost:8200',
8181 >>> 'CAPTURE_HEADERS': True,
8282 >>> 'CAPTURE_BODY': 'all'
8383 >>> })
8484
85- >>> app.add_middleware(ElasticAPM, client=elasticapm )
85+ >>> app.add_middleware(ElasticAPM, client=apm )
8686
8787 Pass an arbitrary APP_NAME and SECRET_TOKEN::
8888
8989 >>> elasticapm = ElasticAPM(app, service_name='myapp', secret_token='asdasdasd')
9090
91- Pass an explicit client::
91+ Pass an explicit client (don't pass in additional options in this case) ::
9292
9393 >>> elasticapm = ElasticAPM(app, client=client)
9494
95- Automatically configure logging::
96-
97- >>> elasticapm = ElasticAPM(app, logging=True)
98-
9995 Capture an exception::
10096
10197 >>> try:
@@ -108,34 +104,69 @@ class ElasticAPM(BaseHTTPMiddleware):
108104 >>> elasticapm.capture_message('hello, world!')
109105 """
110106
111- def __init__ (self , app : ASGIApp , client : Client ):
107+ def __init__ (self , app : ASGIApp , client : Optional [ Client ], ** kwargs ):
112108 """
113109
114110 Args:
115111 app (ASGIApp): Starlette app
116112 client (Client): ElasticAPM Client
117113 """
118- self .client = client
114+ if client :
115+ self .client = client
116+ else :
117+ self .client = make_apm_client (** kwargs )
119118
120119 if self .client .config .instrument and self .client .config .enabled :
121120 elasticapm .instrumentation .control .instrument ()
122121
123- super ().__init__ (app )
124-
125- async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ) -> Response :
126- """Processes the whole request APM capturing.
122+ # If we ever make this a general-use ASGI middleware we should use
123+ # `asgiref.conpatibility.guarantee_single_callable(app)` here
124+ self .app = app
127125
126+ async def __call__ (self , scope , receive , send ):
127+ """
128128 Args:
129- request (Request)
130- call_next (RequestResponseEndpoint): Next request process in Starlette.
131-
132- Returns:
133- Response
129+ scope: ASGI scope dictionary
130+ receive: receive awaitable callable
131+ send: send awaitable callable
134132 """
133+
134+ @functools .wraps (send )
135+ async def wrapped_send (message ):
136+ if message .get ("type" ) == "http.response.start" :
137+ await set_context (
138+ lambda : get_data_from_response (message , self .client .config , constants .TRANSACTION ), "response"
139+ )
140+ result = "HTTP {}xx" .format (message ["status" ] // 100 )
141+ elasticapm .set_transaction_result (result , override = False )
142+ await send (message )
143+
144+ # When we consume the body from receive, we replace the streaming
145+ # mechanism with a mocked version -- this workaround came from
146+ # https://github.com/encode/starlette/issues/495#issuecomment-513138055
147+ body = b""
148+ while True :
149+ message = await receive ()
150+ if not message :
151+ break
152+ if message ["type" ] == "http.request" :
153+ b = message .get ("body" , b"" )
154+ if b :
155+ body += b
156+ if not message .get ("more_body" , False ):
157+ break
158+ if message ["type" ] == "http.disconnect" :
159+ break
160+
161+ async def _receive () -> Message :
162+ await asyncio .sleep (0 )
163+ return {"type" : "http.request" , "body" : body }
164+
165+ request = Request (scope , receive = _receive )
135166 await self ._request_started (request )
136167
137168 try :
138- response = await call_next ( request )
169+ await self . app ( scope , _receive , wrapped_send )
139170 elasticapm .set_transaction_outcome (constants .OUTCOME .SUCCESS , override = False )
140171 except Exception :
141172 await self .capture_exception (
@@ -146,13 +177,9 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
146177 elasticapm .set_context ({"status_code" : 500 }, "response" )
147178
148179 raise
149- else :
150- await self ._request_finished (response )
151180 finally :
152181 self .client .end_transaction ()
153182
154- return response
155-
156183 async def capture_exception (self , * args , ** kwargs ):
157184 """Captures your exception.
158185
@@ -195,19 +222,6 @@ async def _request_started(self, request: Request):
195222 transaction_name = self .get_route_name (request ) or request .url .path
196223 elasticapm .set_transaction_name ("{} {}" .format (request .method , transaction_name ), override = False )
197224
198- async def _request_finished (self , response : Response ):
199- """Captures the end of the request processing to APM.
200-
201- Args:
202- response (Response)
203- """
204- await set_context (
205- lambda : get_data_from_response (response , self .client .config , constants .TRANSACTION ), "response"
206- )
207-
208- result = "HTTP {}xx" .format (response .status_code // 100 )
209- elasticapm .set_transaction_result (result , override = False )
210-
211225 def get_route_name (self , request : Request ) -> str :
212226 app = request .app
213227 scope = request .scope
0 commit comments