Skip to content

Commit e325e2c

Browse files
authored
gin middleware: drop closures (#4186)
no need to use MiddlewareFunc if everything needs to be run for each request
1 parent 482fa87 commit e325e2c

File tree

4 files changed

+84
-94
lines changed

4 files changed

+84
-94
lines changed

pkg/apiserver/apiserver.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,9 @@ func recoverFromPanic(c *gin.Context) {
108108
}
109109

110110
// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one.
111-
func CustomRecoveryWithWriter() gin.HandlerFunc {
112-
return func(c *gin.Context) {
113-
defer recoverFromPanic(c)
114-
c.Next()
115-
}
111+
func CustomRecoveryWithWriter(c *gin.Context) {
112+
defer recoverFromPanic(c)
113+
c.Next()
116114
}
117115

118116
// NewServer creates a LAPI server.
@@ -179,7 +177,7 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo
179177
router.NoRoute(func(c *gin.Context) {
180178
c.JSON(http.StatusNotFound, gin.H{"message": "Page or Method not found"})
181179
})
182-
router.Use(CustomRecoveryWithWriter())
180+
router.Use(CustomRecoveryWithWriter)
183181

184182
controller := &controllers.Controller{
185183
DBClient: dbClient,

pkg/apiserver/controllers/controller.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func (c *Controller) NewV1() error {
9696
}
9797

9898
c.Router.GET("/health", gin.WrapF(serveHealth()))
99-
c.Router.Use(v1.PrometheusMiddleware())
99+
c.Router.Use(v1.PrometheusMiddleware)
100100
// We don't want to compress the response body as it would likely break some existing bouncers
101101
// But we do want to automatically uncompress incoming requests
102102
c.Router.Use(gzip.Gzip(gzip.NoCompression, gzip.WithDecompressOnly(), gzip.WithDecompressFn(gzip.DefaultDecompressHandle)))
@@ -116,7 +116,7 @@ func (c *Controller) NewV1() error {
116116

117117
jwtAuth := groupV1.Group("")
118118
jwtAuth.GET("/refresh_token", c.HandlerV1.Middlewares.JWT.Middleware.RefreshHandler)
119-
jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware())
119+
jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware)
120120
{
121121
jwtAuth.POST("/alerts", c.HandlerV1.CreateAlert)
122122
jwtAuth.GET("/alerts", c.HandlerV1.FindAlerts)
@@ -137,7 +137,7 @@ func (c *Controller) NewV1() error {
137137
}
138138

139139
apiKeyAuth := groupV1.Group("")
140-
apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware())
140+
apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.Middleware, v1.PrometheusBouncersMiddleware)
141141
{
142142
apiKeyAuth.GET("/decisions", c.HandlerV1.GetDecision)
143143
apiKeyAuth.HEAD("/decisions", c.HandlerV1.GetDecision)
@@ -146,7 +146,7 @@ func (c *Controller) NewV1() error {
146146
}
147147

148148
eitherAuth := groupV1.Group("")
149-
eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.MiddlewareFunc()))
149+
eitherAuth.Use(eitherAuthMiddleware(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), c.HandlerV1.Middlewares.APIKey.Middleware))
150150
{
151151
eitherAuth.POST("/usage-metrics", c.HandlerV1.UsageMetrics)
152152
}

pkg/apiserver/controllers/v1/metrics.go

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,51 +31,45 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
3131
}).Inc()
3232
}
3333

34-
func PrometheusMachinesMiddleware() gin.HandlerFunc {
35-
return func(c *gin.Context) {
36-
machineID, _ := getMachineIDFromContext(c)
37-
if machineID == "" {
38-
return
39-
}
40-
41-
metrics.LapiMachineHits.With(prometheus.Labels{
42-
"machine": machineID,
43-
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
44-
"method": c.Request.Method,
45-
}).Inc()
34+
func PrometheusMachinesMiddleware(c *gin.Context) {
35+
machineID, _ := getMachineIDFromContext(c)
36+
if machineID == "" {
37+
return
4638
}
47-
}
4839

49-
func PrometheusBouncersMiddleware() gin.HandlerFunc {
50-
return func(c *gin.Context) {
51-
bouncer, _ := getBouncerFromContext(c)
52-
if bouncer == nil {
53-
return
54-
}
40+
metrics.LapiMachineHits.With(prometheus.Labels{
41+
"machine": machineID,
42+
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
43+
"method": c.Request.Method,
44+
}).Inc()
45+
}
5546

56-
metrics.LapiBouncerHits.With(prometheus.Labels{
57-
"bouncer": bouncer.Name,
58-
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
59-
"method": c.Request.Method,
60-
}).Inc()
47+
func PrometheusBouncersMiddleware(c *gin.Context) {
48+
bouncer, _ := getBouncerFromContext(c)
49+
if bouncer == nil {
50+
return
6151
}
52+
53+
metrics.LapiBouncerHits.With(prometheus.Labels{
54+
"bouncer": bouncer.Name,
55+
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
56+
"method": c.Request.Method,
57+
}).Inc()
6258
}
6359

64-
func PrometheusMiddleware() gin.HandlerFunc {
65-
return func(c *gin.Context) {
66-
startTime := time.Now()
60+
func PrometheusMiddleware(c *gin.Context) {
61+
startTime := time.Now()
6762

68-
metrics.LapiRouteHits.With(prometheus.Labels{
69-
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
70-
"method": c.Request.Method,
71-
}).Inc()
72-
c.Next()
63+
metrics.LapiRouteHits.With(prometheus.Labels{
64+
"route": cmp.Or(c.FullPath(), "invalid-endpoint"),
65+
"method": c.Request.Method,
66+
}).Inc()
67+
c.Next()
7368

74-
elapsed := time.Since(startTime)
75-
metrics.LapiResponseTime.With(
76-
prometheus.Labels{
77-
"method": c.Request.Method,
78-
"endpoint": c.FullPath(),
79-
}).Observe(elapsed.Seconds())
80-
}
69+
elapsed := time.Since(startTime)
70+
metrics.LapiResponseTime.With(
71+
prometheus.Labels{
72+
"method": c.Request.Method,
73+
"endpoint": c.FullPath(),
74+
}).Observe(elapsed.Seconds())
8175
}

pkg/apiserver/middlewares/v1/api_key.go

Lines changed: 43 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -216,65 +216,63 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer {
216216
return bouncer
217217
}
218218

219-
func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
220-
return func(c *gin.Context) {
221-
var bouncer *ent.Bouncer
219+
func (a *APIKey) Middleware(c *gin.Context) {
220+
var bouncer *ent.Bouncer
222221

223-
ctx := c.Request.Context()
222+
ctx := c.Request.Context()
224223

225-
clientIP := c.ClientIP()
224+
clientIP := c.ClientIP()
226225

227-
logger := log.WithField("ip", clientIP)
226+
logger := log.WithField("ip", clientIP)
228227

229-
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
230-
bouncer = a.authTLS(c, logger)
231-
} else {
232-
bouncer = a.authPlain(c, logger)
233-
}
228+
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
229+
bouncer = a.authTLS(c, logger)
230+
} else {
231+
bouncer = a.authPlain(c, logger)
232+
}
234233

235-
if bouncer == nil {
236-
// XXX: StatusUnauthorized?
237-
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
238-
c.Abort()
234+
if bouncer == nil {
235+
// XXX: StatusUnauthorized?
236+
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
237+
c.Abort()
239238

240-
return
241-
}
239+
return
240+
}
242241

243-
// Appsec request, return immediately if we found something
244-
if c.Request.Method == http.MethodHead {
245-
c.Set(BouncerContextKey, bouncer)
246-
return
247-
}
242+
// Appsec request, return immediately if we found something
243+
if c.Request.Method == http.MethodHead {
244+
c.Set(BouncerContextKey, bouncer)
245+
return
246+
}
248247

249-
logger = logger.WithField("name", bouncer.Name)
248+
logger = logger.WithField("name", bouncer.Name)
250249

251-
// 1st time we see this bouncer, we update its IP
252-
if bouncer.IPAddress == "" {
253-
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
254-
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
255-
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
256-
c.Abort()
250+
// 1st time we see this bouncer, we update its IP
251+
if bouncer.IPAddress == "" {
252+
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
253+
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
254+
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
255+
c.Abort()
257256

258-
return
259-
}
257+
return
260258
}
259+
}
261260

262-
useragent := strings.Split(c.Request.UserAgent(), "/")
263-
if len(useragent) != 2 {
264-
logger.Warningf("bad user agent '%s'", c.Request.UserAgent())
265-
useragent = []string{c.Request.UserAgent(), "N/A"}
266-
}
261+
useragent := strings.Split(c.Request.UserAgent(), "/")
262+
if len(useragent) != 2 {
263+
logger.Warningf("bad user agent '%s'", c.Request.UserAgent())
264+
useragent = []string{c.Request.UserAgent(), "N/A"}
265+
}
267266

268-
if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
269-
if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil {
270-
logger.Errorf("failed to update bouncer version and type: %s", err)
271-
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
272-
c.Abort()
267+
if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
268+
if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil {
269+
logger.Errorf("failed to update bouncer version and type: %s", err)
270+
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
271+
c.Abort()
273272

274-
return
275-
}
273+
return
276274
}
277-
278-
c.Set(BouncerContextKey, bouncer)
279275
}
276+
277+
c.Set(BouncerContextKey, bouncer)
280278
}

0 commit comments

Comments
 (0)