Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,9 @@ func recoverFromPanic(c *gin.Context) {
}

// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one.
func CustomRecoveryWithWriter() gin.HandlerFunc {
return func(c *gin.Context) {
defer recoverFromPanic(c)
c.Next()
}
func CustomRecoveryWithWriter(c *gin.Context) {
defer recoverFromPanic(c)
c.Next()
}

// NewServer creates a LAPI server.
Expand Down Expand Up @@ -179,7 +177,7 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo
router.NoRoute(func(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"message": "Page or Method not found"})
})
router.Use(CustomRecoveryWithWriter())
router.Use(CustomRecoveryWithWriter)

controller := &controllers.Controller{
DBClient: dbClient,
Expand Down
8 changes: 4 additions & 4 deletions pkg/apiserver/controllers/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (c *Controller) NewV1() error {
}

c.Router.GET("/health", gin.WrapF(serveHealth()))
c.Router.Use(v1.PrometheusMiddleware())
c.Router.Use(v1.PrometheusMiddleware)
// We don't want to compress the response body as it would likely break some existing bouncers
// But we do want to automatically uncompress incoming requests
c.Router.Use(gzip.Gzip(gzip.NoCompression, gzip.WithDecompressOnly(), gzip.WithDecompressFn(gzip.DefaultDecompressHandle)))
Expand All @@ -116,7 +116,7 @@ func (c *Controller) NewV1() error {

jwtAuth := groupV1.Group("")
jwtAuth.GET("/refresh_token", c.HandlerV1.Middlewares.JWT.Middleware.RefreshHandler)
jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware())
jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware)
{
jwtAuth.POST("/alerts", c.HandlerV1.CreateAlert)
jwtAuth.GET("/alerts", c.HandlerV1.FindAlerts)
Expand All @@ -137,7 +137,7 @@ func (c *Controller) NewV1() error {
}

apiKeyAuth := groupV1.Group("")
apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware())
apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.Middleware, v1.PrometheusBouncersMiddleware)
{
apiKeyAuth.GET("/decisions", c.HandlerV1.GetDecision)
apiKeyAuth.HEAD("/decisions", c.HandlerV1.GetDecision)
Expand All @@ -146,7 +146,7 @@ func (c *Controller) NewV1() error {
}

eitherAuth := groupV1.Group("")
eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc()))
eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.Middleware))
{
eitherAuth.POST("/usage-metrics", c.HandlerV1.UsageMetrics)
}
Expand Down
72 changes: 33 additions & 39 deletions pkg/apiserver/controllers/v1/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,51 +31,45 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
}).Inc()
}

func PrometheusMachinesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
machineID, _ := getMachineIDFromContext(c)
if machineID == "" {
return
}

metrics.LapiMachineHits.With(prometheus.Labels{
"machine": machineID,
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
"method": c.Request.Method,
}).Inc()
func PrometheusMachinesMiddleware(c *gin.Context) {
machineID, _ := getMachineIDFromContext(c)
if machineID == "" {
return
}
}

func PrometheusBouncersMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
bouncer, _ := getBouncerFromContext(c)
if bouncer == nil {
return
}
metrics.LapiMachineHits.With(prometheus.Labels{
"machine": machineID,
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
"method": c.Request.Method,
}).Inc()
}

metrics.LapiBouncerHits.With(prometheus.Labels{
"bouncer": bouncer.Name,
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
"method": c.Request.Method,
}).Inc()
func PrometheusBouncersMiddleware(c *gin.Context) {
bouncer, _ := getBouncerFromContext(c)
if bouncer == nil {
return
}

metrics.LapiBouncerHits.With(prometheus.Labels{
"bouncer": bouncer.Name,
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
"method": c.Request.Method,
}).Inc()
}

func PrometheusMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
startTime := time.Now()
func PrometheusMiddleware(c *gin.Context) {
startTime := time.Now()

metrics.LapiRouteHits.With(prometheus.Labels{
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
"method": c.Request.Method,
}).Inc()
c.Next()
metrics.LapiRouteHits.With(prometheus.Labels{
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
"method": c.Request.Method,
}).Inc()
c.Next()

elapsed := time.Since(startTime)
metrics.LapiResponseTime.With(
prometheus.Labels{
"method": c.Request.Method,
"endpoint": c.FullPath(),
}).Observe(elapsed.Seconds())
}
elapsed := time.Since(startTime)
metrics.LapiResponseTime.With(
prometheus.Labels{
"method": c.Request.Method,
"endpoint": c.FullPath(),
}).Observe(elapsed.Seconds())
}
88 changes: 43 additions & 45 deletions pkg/apiserver/middlewares/v1/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,65 +216,63 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer {
return bouncer
}

func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
return func(c *gin.Context) {
var bouncer *ent.Bouncer
func (a *APIKey) Middleware(c *gin.Context) {
var bouncer *ent.Bouncer

ctx := c.Request.Context()
ctx := c.Request.Context()

clientIP := c.ClientIP()
clientIP := c.ClientIP()

logger := log.WithField("ip", clientIP)
logger := log.WithField("ip", clientIP)

if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
bouncer = a.authTLS(c, logger)
} else {
bouncer = a.authPlain(c, logger)
}
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
bouncer = a.authTLS(c, logger)
} else {
bouncer = a.authPlain(c, logger)
}

if bouncer == nil {
// XXX: StatusUnauthorized?
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
if bouncer == nil {
// XXX: StatusUnauthorized?
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()

return
}
return
}

// Appsec request, return immediately if we found something
if c.Request.Method == http.MethodHead {
c.Set(BouncerContextKey, bouncer)
return
}
// Appsec request, return immediately if we found something
if c.Request.Method == http.MethodHead {
c.Set(BouncerContextKey, bouncer)
return
}

logger = logger.WithField("name", bouncer.Name)
logger = logger.WithField("name", bouncer.Name)

// 1st time we see this bouncer, we update its IP
if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
// 1st time we see this bouncer, we update its IP
if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()

return
}
return
}
}

useragent := strings.Split(c.Request.UserAgent(), "/")
if len(useragent) != 2 {
logger.Warningf("bad user agent '%s'", c.Request.UserAgent())
useragent = []string{c.Request.UserAgent(), "N/A"}
}
useragent := strings.Split(c.Request.UserAgent(), "/")
if len(useragent) != 2 {
logger.Warningf("bad user agent '%s'", c.Request.UserAgent())
useragent = []string{c.Request.UserAgent(), "N/A"}
}

if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil {
logger.Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort()
if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil {
logger.Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort()

return
}
return
}

c.Set(BouncerContextKey, bouncer)
}

c.Set(BouncerContextKey, bouncer)
}
Loading