Skip to content

Commit 626f13e

Browse files
committed
CSRF/RequestID mw: switch math/random usage to crypto/random
1 parent 3f8ae15 commit 626f13e

File tree

6 files changed

+47
-9
lines changed

6 files changed

+47
-9
lines changed

middleware/csrf.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"time"
77

88
"github.com/labstack/echo/v4"
9-
"github.com/labstack/gommon/random"
109
)
1110

1211
type (
@@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
103102
if config.TokenLength == 0 {
104103
config.TokenLength = DefaultCSRFConfig.TokenLength
105104
}
105+
106106
if config.TokenLookup == "" {
107107
config.TokenLookup = DefaultCSRFConfig.TokenLookup
108108
}
@@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
132132

133133
token := ""
134134
if k, err := c.Cookie(config.CookieName); err != nil {
135-
token = random.String(config.TokenLength) // Generate token
135+
token = randomString(config.TokenLength)
136136
} else {
137137
token = k.Value // Reuse token
138138
}

middleware/csrf_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"testing"
99

1010
"github.com/labstack/echo/v4"
11-
"github.com/labstack/gommon/random"
1211
"github.com/stretchr/testify/assert"
1312
)
1413

@@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) {
233232
assert.Error(t, h(c))
234233

235234
// Valid CSRF token
236-
token := random.String(32)
235+
token := randomString(32)
237236
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
238237
req.Header.Set(echo.HeaderXCSRFToken, token)
239238
if assert.NoError(t, h(c)) {

middleware/rate_limiter_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"time"
1111

1212
"github.com/labstack/echo/v4"
13-
"github.com/labstack/gommon/random"
1413
"github.com/stretchr/testify/assert"
1514
"golang.org/x/time/rate"
1615
)
@@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
410409
func generateAddressList(count int) []string {
411410
addrs := make([]string, count)
412411
for i := 0; i < count; i++ {
413-
addrs[i] = random.String(15)
412+
addrs[i] = randomString(15)
414413
}
415414
return addrs
416415
}

middleware/request_id.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package middleware
22

33
import (
44
"github.com/labstack/echo/v4"
5-
"github.com/labstack/gommon/random"
65
)
76

87
type (
@@ -12,7 +11,7 @@ type (
1211
Skipper Skipper
1312

1413
// Generator defines a function to generate an ID.
15-
// Optional. Default value random.String(32).
14+
// Optional. Defaults to generator for random string of length 32.
1615
Generator func() string
1716

1817
// RequestIDHandler defines a function which is executed for a request id.
@@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
7372
}
7473

7574
func generator() string {
76-
return random.String(32)
75+
return randomString(32)
7776
}

middleware/util.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package middleware
22

33
import (
4+
"crypto/rand"
5+
"fmt"
46
"strings"
57
)
68

@@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool {
5254
}
5355
return false
5456
}
57+
58+
func randomString(length uint8) string {
59+
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
60+
61+
bytes := make([]byte, length)
62+
_, err := rand.Read(bytes)
63+
if err != nil {
64+
// we are out of random. let the request fail
65+
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
66+
}
67+
for i, b := range bytes {
68+
bytes[i] = charset[b%byte(len(charset))]
69+
}
70+
return string(bytes)
71+
}

middleware/util_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) {
9393
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
9494
}
9595
}
96+
97+
func TestRandomString(t *testing.T) {
98+
var testCases = []struct {
99+
name string
100+
whenLength uint8
101+
expect string
102+
}{
103+
{
104+
name: "ok, 16",
105+
whenLength: 16,
106+
},
107+
{
108+
name: "ok, 32",
109+
whenLength: 32,
110+
},
111+
}
112+
113+
for _, tc := range testCases {
114+
t.Run(tc.name, func(t *testing.T) {
115+
uid := randomString(tc.whenLength)
116+
assert.Len(t, uid, int(tc.whenLength))
117+
})
118+
}
119+
}

0 commit comments

Comments
 (0)