11import re
2- from typing import Any , Literal , Union , cast
2+ from typing import Any , Literal , TypedDict , Union , cast
33
44from django .core .exceptions import ValidationError as DjangoValidationError
55from django .http .response import HttpResponse
2727from social_django .models import UserSocialAuth
2828from social_django .utils import load_backend , load_strategy
2929
30+ from posthog .cloud_utils import get_cached_instance_license
3031from posthog .constants import AvailableFeature
32+ from posthog .exceptions_capture import capture_exception
3133from posthog .models .organization import OrganizationMembership
3234from posthog .models .organization_domain import OrganizationDomain
3335
@@ -258,6 +260,15 @@ def get_user_id(self, details, response):
258260logger = structlog .get_logger (__name__ )
259261
260262
263+ def _get_bearer_token (request : Request ) -> str | None :
264+ """Extract bearer token from Authorization header."""
265+ if auth_header := request .headers .get ("authorization" ):
266+ parts = auth_header .split ()
267+ if len (parts ) == 2 and parts [0 ].lower () == "bearer" :
268+ return parts [1 ]
269+ return None
270+
271+
261272class VercelAuthentication (authentication .BaseAuthentication ):
262273 """
263274 Implements Vercel Marketplace API authentication.
@@ -277,7 +288,7 @@ class VercelAuthentication(authentication.BaseAuthentication):
277288 VERCEL_ISSUER = "https://marketplace.vercel.com"
278289
279290 def authenticate (self , request : Request ) -> tuple [VercelUser , None ] | None :
280- token = self . _get_bearer_token (request )
291+ token = _get_bearer_token (request )
281292 if not token :
282293 raise AuthenticationFailed ("Missing Token for Vercel request" )
283294
@@ -293,14 +304,6 @@ def authenticate(self, request: Request) -> tuple[VercelUser, None] | None:
293304 logger .exception ("Vercel auth error" , auth_type = auth_type , error = str (e ), integration = "vercel" )
294305 raise AuthenticationFailed (f"{ auth_type .title ()} authentication failed" )
295306
296- def _get_bearer_token (self , request : Request ) -> str | None :
297- if auth_header := request .headers .get ("authorization" ):
298- parts = auth_header .split ()
299- if len (parts ) == 2 and parts [0 ].lower () == "bearer" :
300- return parts [1 ]
301-
302- return None
303-
304307 def _get_vercel_auth_type (self , request : Request ) -> "VercelAuthentication.VercelAuthType" :
305308 auth_type = request .headers .get ("X-Vercel-Auth" , "" ).lower ()
306309
@@ -400,6 +403,87 @@ def _decode_token(self, token: str) -> dict[str, Any]:
400403 )
401404
402405
406+ class BillingServiceJWTPayload (TypedDict ):
407+ organization_id : str
408+ aud : str
409+ exp : int
410+
411+
412+ class BillingServiceUser :
413+ """
414+ Represents an authenticated billing service request.
415+ Contains the organization_id from the validated JWT.
416+ """
417+
418+ def __init__ (self , organization_id : str ):
419+ self .organization_id = organization_id
420+
421+ @property
422+ def is_authenticated (self ) -> bool :
423+ return True
424+
425+
426+ class BillingServiceAuthentication (authentication .BaseAuthentication ):
427+ """
428+ Authenticates requests from the billing service to PostHog.
429+
430+ The billing service signs JWTs using the shared license secret (same secret PostHog
431+ uses when calling the billing service, but in reverse direction).
432+ """
433+
434+ EXPECTED_AUDIENCE = "billing:posthog-proxy"
435+
436+ def authenticate (self , request : Request ) -> tuple [BillingServiceUser , None ] | None :
437+ token = _get_bearer_token (request )
438+ if not token :
439+ raise AuthenticationFailed ("Missing authorization token" )
440+
441+ try :
442+ payload = self ._validate_jwt_token (token )
443+ except jwt .ExpiredSignatureError as e :
444+ capture_exception (e )
445+ logger .warning ("Billing service token expired" )
446+ raise AuthenticationFailed ("Token has expired" )
447+ except jwt .InvalidAudienceError as e :
448+ capture_exception (e )
449+ logger .warning ("Billing service token has invalid audience" )
450+ raise AuthenticationFailed ("Invalid token audience" )
451+ except jwt .InvalidTokenError as e :
452+ capture_exception (e )
453+ logger .exception ("Billing service auth failed" , error = str (e ))
454+ raise AuthenticationFailed ("Invalid authentication token" )
455+
456+ organization_id = payload .get ("organization_id" )
457+ if not organization_id :
458+ capture_exception (ValueError ("Billing service token missing organization_id" ))
459+ logger .warning ("Billing service token missing organization_id" )
460+ raise AuthenticationFailed ("Missing organization_id in token" )
461+
462+ return BillingServiceUser (organization_id = organization_id ), None
463+
464+ def _validate_jwt_token (self , token : str ) -> BillingServiceJWTPayload :
465+ license = get_cached_instance_license ()
466+ if not license or not license .key :
467+ capture_exception (ValueError ("Billing service auth failed: no license configured" ))
468+ logger .error ("Billing service auth failed: no license configured" )
469+ raise AuthenticationFailed ("No license configured" )
470+
471+ # Extract the secret from the license key (format: "id::secret")
472+ try :
473+ license_secret = license .key .split ("::" )[1 ]
474+ except IndexError :
475+ capture_exception (ValueError ("Billing service auth failed: invalid license key format" ))
476+ logger .exception ("Billing service auth failed: invalid license key format" )
477+ raise AuthenticationFailed ("Invalid license key format" )
478+
479+ return jwt .decode (
480+ token ,
481+ license_secret ,
482+ algorithms = ["HS256" ],
483+ audience = self .EXPECTED_AUDIENCE ,
484+ )
485+
486+
403487def social_auth_allowed (backend , details , response , * args , ** kwargs ) -> None :
404488 email = details .get ("email" )
405489 # Check if SSO enforcement is enabled for this email address
0 commit comments