Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/middleware/limiter.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ app.Use(limiter.New(limiter.Config{
return 20
},
Expiration: 30 * time.Second,
ExpirationFunc: func(c fiber.Ctx) time.Duration {
Copy link

Copilot AI Jan 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ExpirationFunc in this example is redundant because it returns the same value as the static Expiration field above it (30 * time.Second). For a more useful demonstration, consider showing either a different value or a dynamic calculation based on the request context, similar to the dynamic expiration example below.

Suggested change
ExpirationFunc: func(c fiber.Ctx) time.Duration {
ExpirationFunc: func(c fiber.Ctx) time.Duration {
if c.Path() == "/login" {
return 60 * time.Second
}

Copilot uses AI. Check for mistakes.
return 30 * time.Second
},
KeyGenerator: func(c fiber.Ctx) string {
return c.Get("x-forwarded-for")
},
Expand Down Expand Up @@ -99,6 +102,21 @@ app.Use(limiter.New(limiter.Config{
}))
```

## Dynamic expiration

You can also calculate the expiration dynamically using the `ExpirationFunc` parameter. It receives the request context and allows you to set a different expiration window for each request.

Example:

```go
app.Use(limiter.New(limiter.Config{
Max: 20,
ExpirationFunc: func(c fiber.Ctx) time.Duration {
return getExpirationForRoute(c.Path())
},
}))
```

## Config

| Property | Type | Description | Default |
Expand All @@ -108,6 +126,7 @@ app.Use(limiter.New(limiter.Config{
| MaxFunc | `func(fiber.Ctx) int` | Function that calculates the maximum number of recent connections within `Expiration` seconds before sending a 429 response. | A function that returns `cfg.Max` |
| KeyGenerator | `func(fiber.Ctx) string` | Function to generate custom keys; uses `c.IP()` by default. | A function using `c.IP()` as the default |
| Expiration | `time.Duration` | Duration to keep request records in memory. | 1 * time.Minute |
| ExpirationFunc | `func(fiber.Ctx) time.Duration` | Function that calculates the expiration duration dynamically. | A function that returns `cfg.Expiration` |
| LimitReached | `fiber.Handler` | Called when a request exceeds the limit. | A function sending a 429 response |
| SkipFailedRequests | `bool` | When set to `true`, requests with status code ≥ 400 aren't counted. | false |
| SkipSuccessfulRequests | `bool` | When set to `true`, requests with status code < 400 aren't counted. | false |
Expand All @@ -129,6 +148,9 @@ var ConfigDefault = Config{
return 5
},
Expiration: 1 * time.Minute,
ExpirationFunc: func(c fiber.Ctx) time.Duration {
return 1 * time.Minute
},
KeyGenerator: func(c fiber.Ctx) string {
return c.IP()
},
Expand Down
15 changes: 15 additions & 0 deletions middleware/limiter/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ type Config struct {
// }
MaxFunc func(c fiber.Ctx) int

// A function to dynamically calculate the expiration time for rate limiter entries
//
// Default: func(c fiber.Ctx) time.Duration {
// return c.Expiration
// }
ExpirationFunc func(c fiber.Ctx) time.Duration

// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c fiber.Ctx) string {
Expand Down Expand Up @@ -83,6 +90,9 @@ var ConfigDefault = Config{
MaxFunc: func(_ fiber.Ctx) int {
return defaultLimiterMax
},
ExpirationFunc: func(_ fiber.Ctx) time.Duration {
return 1 * time.Minute
},
KeyGenerator: func(c fiber.Ctx) string {
return c.IP()
},
Expand Down Expand Up @@ -130,5 +140,10 @@ func configDefault(config ...Config) Config {
return cfg.Max
}
}
if cfg.ExpirationFunc == nil {
cfg.ExpirationFunc = func(_ fiber.Ctx) time.Duration {
return cfg.Expiration
}
}
return cfg
}
15 changes: 8 additions & 7 deletions middleware/limiter/limiter_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ func (FixedWindow) New(cfg *Config) fiber.Handler {
cfg = &defaultCfg
}

var (
// Limiter variables
mux = &sync.RWMutex{}
expiration = uint64(cfg.Expiration.Seconds())
)
// Limiter variables
mux := &sync.RWMutex{}

// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage, !cfg.DisableValueRedaction)
Expand All @@ -41,6 +38,10 @@ func (FixedWindow) New(cfg *Config) fiber.Handler {
return c.Next()
}

// Generate expiration from generator
expirationDuration := cfg.ExpirationFunc(c)
expiration := uint64(expirationDuration.Seconds())
Comment on lines +41 to +46
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Same duration validation needed here.

The same conversion safety concerns from limiter_sliding.go apply here. Please add validation to ensure ExpirationFunc returns a positive, reasonable duration before converting to uint64.

🔎 Suggested validation
 // Generate expiration from generator
 expirationDuration := cfg.ExpirationFunc(c)
+if expirationDuration <= 0 {
+	return fmt.Errorf("limiter: ExpirationFunc must return a positive duration, got %v", expirationDuration)
+}
 expiration := uint64(expirationDuration.Seconds())
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Generate expiration from generator
expirationDuration := cfg.ExpirationFunc(c)
expiration := uint64(expirationDuration.Seconds())
// Generate expiration from generator
expirationDuration := cfg.ExpirationFunc(c)
if expirationDuration <= 0 {
return fmt.Errorf("limiter: ExpirationFunc must return a positive duration, got %v", expirationDuration)
}
expiration := uint64(expirationDuration.Seconds())
🤖 Prompt for AI Agents
In middleware/limiter/limiter_fixed.go around lines 41-43, validate the duration
returned by cfg.ExpirationFunc(c) before converting to uint64: ensure the
duration is positive (>0) and clamp it to a safe upper bound to avoid overflow
when calling Seconds(); if the duration is <=0, set a sensible default (e.g.,
1s) and log or warn; if Seconds() exceeds the max uint64 value, clamp to
math.MaxUint64 (or a defined max) before casting to uint64. Implement these
checks immediately after calling ExpirationFunc and use the validated/clamped
value for expiration.


// Get key from request
key := cfg.KeyGenerator(c)

Expand Down Expand Up @@ -78,7 +79,7 @@ func (FixedWindow) New(cfg *Config) fiber.Handler {
remaining := maxRequests - e.currHits

// Update storage
if setErr := manager.set(reqCtx, key, e, cfg.Expiration); setErr != nil {
if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil {
mux.Unlock()
return fmt.Errorf("limiter: failed to persist state: %w", setErr)
}
Expand Down Expand Up @@ -118,7 +119,7 @@ func (FixedWindow) New(cfg *Config) fiber.Handler {
e = entry
e.currHits--
remaining++
if setErr := manager.set(reqCtx, key, e, cfg.Expiration); setErr != nil {
if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil {
mux.Unlock()
return fmt.Errorf("limiter: failed to persist state: %w", setErr)
}
Expand Down
11 changes: 6 additions & 5 deletions middleware/limiter/limiter_sliding.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ func (SlidingWindow) New(cfg *Config) fiber.Handler {
cfg = &defaultCfg
}

var (
// Limiter variables
mux = &sync.RWMutex{}
expiration = uint64(cfg.Expiration.Seconds())
)
// Limiter variables
mux := &sync.RWMutex{}

// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage, !cfg.DisableValueRedaction)
Expand All @@ -43,6 +40,10 @@ func (SlidingWindow) New(cfg *Config) fiber.Handler {
return c.Next()
}

// Generate expiration from generator
expirationDuration := cfg.ExpirationFunc(c)
expiration := uint64(expirationDuration.Seconds())

// Get key from request
key := cfg.KeyGenerator(c)

Expand Down
70 changes: 70 additions & 0 deletions middleware/limiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,76 @@ func Test_Limiter_With_Max_Func(t *testing.T) {
require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration -race -v
func Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Use(New(Config{
Max: 2,
Expiration: 10 * time.Second,
ExpirationFunc: func(fiber.Ctx) time.Duration { return 2 * time.Second },
LimiterMiddleware: FixedWindow{},
}))

app.Get("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode)

time.Sleep(3 * time.Second)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration -race -v
func Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Use(New(Config{
Max: 2,
Expiration: 10 * time.Second,
ExpirationFunc: func(fiber.Ctx) time.Duration { return 2 * time.Second },
LimiterMiddleware: SlidingWindow{},
}))

app.Get("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode)

time.Sleep(5 * time.Second)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) {
t.Parallel()
Expand Down
Loading