@@ -51,9 +51,14 @@ const (
5151 // Format for the "jwks_uri" endpoint
5252 // The %s refers to the tenant ID
5353 jwksURIFormat = "https://login.microsoftonline.com/%s/discovery/v2.0/keys"
54- // Format for the "iss" claim in the JWT
54+ // Format for the "iss" claim in the JWT (Azure AD v2.0)
5555 // The %s refers to the tenant ID
5656 jwtIssuerFormat = "https://login.microsoftonline.com/%s/v2.0"
57+ // Format for the "iss" claim in the JWT (Azure AD v1.0 - legacy)
58+ // The %s refers to the tenant ID
59+ jwtIssuerV1Format = "https://sts.windows.net/%s/"
60+ // Eventgrid managed identity issuer
61+ eventGridIssuer = "https://eventgrid.azure.net"
5762)
5863
5964// AzureEventGrid allows sending/receiving Azure Event Grid events.
@@ -111,34 +116,113 @@ func (a *AzureEventGrid) Init(_ context.Context, metadata bindings.Metadata) err
111116
112117var matchAuthHeader = regexp .MustCompile (`(?i)^(Bearer )?(([A-Za-z0-9_-]+\.){2}[A-Za-z0-9_-]+)$` )
113118
114- func (a * AzureEventGrid ) validateAuthHeader (ctx context. Context , authorizationHeader string ) bool {
119+ func (a * AzureEventGrid ) validateAuthHeader (ctx * fasthttp. RequestCtx ) bool {
115120 // Extract the bearer token from the header
121+ authorizationHeader := string (ctx .Request .Header .Peek ("authorization" ))
116122 if authorizationHeader == "" {
117123 a .logger .Error ("Incoming webhook request does not contain an Authorization header" )
118124 return false
119125 }
126+
120127 match := matchAuthHeader .FindStringSubmatch (authorizationHeader )
121128 if len (match ) < 3 {
122129 a .logger .Error ("Incoming webhook request does not contain a valid bearer token in the Authorization header" )
123130 return false
124131 }
125132 token := match [2 ]
126133
127- // Validate the JWT
128- _ , err := jwt .ParseString (
129- token ,
130- jwt .WithKeySet (a .jwks , jws .WithInferAlgorithmFromKey (true )),
131- jwt .WithAudience (a .metadata .azureClientID ),
132- jwt .WithIssuer (fmt .Sprintf (jwtIssuerFormat , a .metadata .azureTenantID )),
133- jwt .WithAcceptableSkew (5 * time .Minute ),
134- jwt .WithContext (ctx ),
135- )
134+ // First, parse the JWT to see what claims we received
135+ parsedToken , err := jwt .ParseString (token , jwt .WithVerify (false ))
136136 if err != nil {
137- a .logger .Errorf ("Failed to validate JWT in the incoming webhook request : %v" , err )
137+ a .logger .Errorf ("Failed to parse JWT: %v" , err )
138138 return false
139139 }
140140
141- return true
141+ actualIssuer := parsedToken .Issuer ()
142+ azureADV2Issuer := fmt .Sprintf (jwtIssuerFormat , a .metadata .azureTenantID )
143+ expectedAudience := a .metadata .azureClientID
144+ switch actualIssuer {
145+ case azureADV2Issuer :
146+ // AzureAD v2.0 issuer
147+ _ , err = jwt .ParseString (
148+ token ,
149+ jwt .WithKeySet (a .jwks , jws .WithInferAlgorithmFromKey (true )),
150+ jwt .WithAudience (expectedAudience ),
151+ jwt .WithIssuer (azureADV2Issuer ),
152+ jwt .WithAcceptableSkew (5 * time .Minute ),
153+ jwt .WithContext (context .Background ()),
154+ )
155+ if err == nil {
156+ return true
157+ }
158+
159+ // Also check webhook URL as audience
160+ _ , err = jwt .ParseString (
161+ token ,
162+ jwt .WithKeySet (a .jwks , jws .WithInferAlgorithmFromKey (true )),
163+ jwt .WithAudience (a .metadata .SubscriberEndpoint ),
164+ jwt .WithIssuer (azureADV2Issuer ),
165+ jwt .WithAcceptableSkew (5 * time .Minute ),
166+ jwt .WithContext (context .Background ()),
167+ )
168+ if err == nil {
169+ return true
170+ }
171+
172+ a .logger .Errorf ("JWT validation failed for AzureAD v2.0 issuer" )
173+ return false
174+
175+ case fmt .Sprintf (jwtIssuerV1Format , a .metadata .azureTenantID ):
176+ // AzureAD v1.0 issuer
177+ a .logger .Infof ("Detected AzureAD v1.0 issuer, validating..." )
178+ _ , err = jwt .ParseString (
179+ token ,
180+ jwt .WithKeySet (a .jwks , jws .WithInferAlgorithmFromKey (true )),
181+ jwt .WithAudience (expectedAudience ),
182+ jwt .WithIssuer (actualIssuer ),
183+ jwt .WithAcceptableSkew (5 * time .Minute ),
184+ jwt .WithContext (context .Background ()),
185+ )
186+ if err == nil {
187+ return true
188+ }
189+ _ , err = jwt .ParseString (
190+ token ,
191+ jwt .WithKeySet (a .jwks , jws .WithInferAlgorithmFromKey (true )),
192+ jwt .WithAudience (a .metadata .SubscriberEndpoint ),
193+ jwt .WithIssuer (actualIssuer ),
194+ jwt .WithAcceptableSkew (5 * time .Minute ),
195+ jwt .WithContext (context .Background ()),
196+ )
197+ if err == nil {
198+ return true
199+ }
200+
201+ a .logger .Errorf ("JWT validation failed for AzureAD v1.0 issuer" )
202+ return false
203+
204+ case eventGridIssuer :
205+ // eventgrid managed identity issuer - use webhook URL as audience
206+ _ , err = jwt .ParseString (
207+ token ,
208+ jwt .WithKeySet (a .jwks , jws .WithInferAlgorithmFromKey (true )),
209+ jwt .WithAudience (a .metadata .SubscriberEndpoint ),
210+ jwt .WithIssuer (eventGridIssuer ),
211+ jwt .WithAcceptableSkew (5 * time .Minute ),
212+ jwt .WithContext (context .Background ()),
213+ )
214+ if err == nil {
215+ return true
216+ }
217+
218+ a .logger .Errorf ("JWT validation failed for eventgrid issuer: %v" , err )
219+ return false
220+
221+ default :
222+ a .logger .Errorf ("Unexpected JWT issuer: %s. Expected either '%s', '%s', or '%s'" ,
223+ actualIssuer , azureADV2Issuer , fmt .Sprintf (jwtIssuerV1Format , a .metadata .azureTenantID ), eventGridIssuer )
224+ return false
225+ }
142226}
143227
144228// Initializes the JWKS cache
@@ -284,10 +368,12 @@ func (a *AzureEventGrid) requestHandler(handler bindings.Handler) fasthttp.Reque
284368 return
285369 }
286370
287- // Validate the Authorization header
288- authorizationHeader := string (ctx .Request .Header .Peek ("authorization" ))
289- // Note that ctx is a fasthttp context so it's actually tied to the server's lifecycle and not the request's
290- if ! a .validateAuthHeader (ctx , authorizationHeader ) {
371+ // Options requests (webhook validation handshake) don't require authentication
372+ // Azure Event Grid sends options requests without authorization header during initial validation
373+ if method == http .MethodOptions {
374+ // Skip authentication for options requests
375+ } else if ! a .validateAuthHeader (ctx ) {
376+ // Note that ctx is a fasthttp context so it's actually tied to the server's lifecycle and not the request's
291377 ctx .Response .Header .SetStatusCode (http .StatusUnauthorized )
292378 _ , err = ctx .Response .BodyWriter ().Write ([]byte ("401 Unauthorized" ))
293379 if err != nil {
0 commit comments