diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 2621de021c5..e1fd9d66b3b 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -108,11 +108,9 @@ func recoverFromPanic(c *gin.Context) { } // CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one. -func CustomRecoveryWithWriter() gin.HandlerFunc { - return func(c *gin.Context) { - defer recoverFromPanic(c) - c.Next() - } +func CustomRecoveryWithWriter(c *gin.Context) { + defer recoverFromPanic(c) + c.Next() } // NewServer creates a LAPI server. @@ -179,7 +177,7 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo router.NoRoute(func(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"message": "Page or Method not found"}) }) - router.Use(CustomRecoveryWithWriter()) + router.Use(CustomRecoveryWithWriter) controller := &controllers.Controller{ DBClient: dbClient, diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index ab8dc6501fb..f4d35e20df0 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -96,7 +96,7 @@ func (c *Controller) NewV1() error { } c.Router.GET("/health", gin.WrapF(serveHealth())) - c.Router.Use(v1.PrometheusMiddleware()) + c.Router.Use(v1.PrometheusMiddleware) // We don't want to compress the response body as it would likely break some existing bouncers // But we do want to automatically uncompress incoming requests c.Router.Use(gzip.Gzip(gzip.NoCompression, gzip.WithDecompressOnly(), gzip.WithDecompressFn(gzip.DefaultDecompressHandle))) @@ -116,7 +116,7 @@ func (c *Controller) NewV1() error { jwtAuth := groupV1.Group("") jwtAuth.GET("/refresh_token", c.HandlerV1.Middlewares.JWT.Middleware.RefreshHandler) - jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware()) + jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware) { jwtAuth.POST("/alerts", c.HandlerV1.CreateAlert) jwtAuth.GET("/alerts", c.HandlerV1.FindAlerts) @@ -137,7 +137,7 @@ func (c *Controller) NewV1() error { } apiKeyAuth := groupV1.Group("") - apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware()) + apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.Middleware, v1.PrometheusBouncersMiddleware) { apiKeyAuth.GET("/decisions", c.HandlerV1.GetDecision) apiKeyAuth.HEAD("/decisions", c.HandlerV1.GetDecision) @@ -146,7 +146,7 @@ func (c *Controller) NewV1() error { } eitherAuth := groupV1.Group("") - eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc())) + eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.Middleware)) { eitherAuth.POST("/usage-metrics", c.HandlerV1.UsageMetrics) } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 2c956b8e392..01a35fc199e 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -31,51 +31,45 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { }).Inc() } -func PrometheusMachinesMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - machineID, _ := getMachineIDFromContext(c) - if machineID == "" { - return - } - - metrics.LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": cmp.Or(c.FullPath(), "invalid-endpoint"), - "method": c.Request.Method, - }).Inc() +func PrometheusMachinesMiddleware(c *gin.Context) { + machineID, _ := getMachineIDFromContext(c) + if machineID == "" { + return } -} -func PrometheusBouncersMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - bouncer, _ := getBouncerFromContext(c) - if bouncer == nil { - return - } + metrics.LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": cmp.Or(c.FullPath(), "invalid-endpoint"), + "method": c.Request.Method, + }).Inc() +} - metrics.LapiBouncerHits.With(prometheus.Labels{ - "bouncer": bouncer.Name, - "route": cmp.Or(c.FullPath(), "invalid-endpoint"), - "method": c.Request.Method, - }).Inc() +func PrometheusBouncersMiddleware(c *gin.Context) { + bouncer, _ := getBouncerFromContext(c) + if bouncer == nil { + return } + + metrics.LapiBouncerHits.With(prometheus.Labels{ + "bouncer": bouncer.Name, + "route": cmp.Or(c.FullPath(), "invalid-endpoint"), + "method": c.Request.Method, + }).Inc() } -func PrometheusMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - startTime := time.Now() +func PrometheusMiddleware(c *gin.Context) { + startTime := time.Now() - metrics.LapiRouteHits.With(prometheus.Labels{ - "route": cmp.Or(c.FullPath(), "invalid-endpoint"), - "method": c.Request.Method, - }).Inc() - c.Next() + metrics.LapiRouteHits.With(prometheus.Labels{ + "route": cmp.Or(c.FullPath(), "invalid-endpoint"), + "method": c.Request.Method, + }).Inc() + c.Next() - elapsed := time.Since(startTime) - metrics.LapiResponseTime.With( - prometheus.Labels{ - "method": c.Request.Method, - "endpoint": c.FullPath(), - }).Observe(elapsed.Seconds()) - } + elapsed := time.Since(startTime) + metrics.LapiResponseTime.With( + prometheus.Labels{ + "method": c.Request.Method, + "endpoint": c.FullPath(), + }).Observe(elapsed.Seconds()) } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 2f23da7bb62..5c9d696a431 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -216,65 +216,63 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return bouncer } -func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { - return func(c *gin.Context) { - var bouncer *ent.Bouncer +func (a *APIKey) Middleware(c *gin.Context) { + var bouncer *ent.Bouncer - ctx := c.Request.Context() + ctx := c.Request.Context() - clientIP := c.ClientIP() + clientIP := c.ClientIP() - logger := log.WithField("ip", clientIP) + logger := log.WithField("ip", clientIP) - if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - bouncer = a.authTLS(c, logger) - } else { - bouncer = a.authPlain(c, logger) - } + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { + bouncer = a.authTLS(c, logger) + } else { + bouncer = a.authPlain(c, logger) + } - if bouncer == nil { - // XXX: StatusUnauthorized? - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() + if bouncer == nil { + // XXX: StatusUnauthorized? + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() - return - } + return + } - // Appsec request, return immediately if we found something - if c.Request.Method == http.MethodHead { - c.Set(BouncerContextKey, bouncer) - return - } + // Appsec request, return immediately if we found something + if c.Request.Method == http.MethodHead { + c.Set(BouncerContextKey, bouncer) + return + } - logger = logger.WithField("name", bouncer.Name) + logger = logger.WithField("name", bouncer.Name) - // 1st time we see this bouncer, we update its IP - if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { - logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() + // 1st time we see this bouncer, we update its IP + if bouncer.IPAddress == "" { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { + logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() - return - } + return } + } - useragent := strings.Split(c.Request.UserAgent(), "/") - if len(useragent) != 2 { - logger.Warningf("bad user agent '%s'", c.Request.UserAgent()) - useragent = []string{c.Request.UserAgent(), "N/A"} - } + useragent := strings.Split(c.Request.UserAgent(), "/") + if len(useragent) != 2 { + logger.Warningf("bad user agent '%s'", c.Request.UserAgent()) + useragent = []string{c.Request.UserAgent(), "N/A"} + } - if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { - logger.Errorf("failed to update bouncer version and type: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) - c.Abort() + if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { + logger.Errorf("failed to update bouncer version and type: %s", err) + c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) + c.Abort() - return - } + return } - - c.Set(BouncerContextKey, bouncer) } + + c.Set(BouncerContextKey, bouncer) }