Skip to content

Commit 7828135

Browse files
committed
Add Context.Get generic functions
1 parent 321530d commit 7828135

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

context_generic.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// SPDX-License-Identifier: MIT
2+
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
3+
4+
package echo
5+
6+
import "errors"
7+
8+
// ErrNonExistentKey is error that is returned when key does not exist
9+
var ErrNonExistentKey = errors.New("non existent key")
10+
11+
// ErrInvalidKeyType is error that is returned when the value is not castable to expected type.
12+
var ErrInvalidKeyType = errors.New("invalid key type")
13+
14+
// ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing.
15+
// Returns ErrInvalidKeyType error if the value is not castable to type T.
16+
func ContextGet[T any](c Context, key string) (T, error) {
17+
val := c.Get(key)
18+
if val == any(nil) {
19+
var zero T
20+
return zero, ErrNonExistentKey
21+
}
22+
23+
typed, ok := val.(T)
24+
if !ok {
25+
var zero T
26+
return zero, ErrInvalidKeyType
27+
}
28+
29+
return typed, nil
30+
}
31+
32+
// ContextGetOr retrieves a value from the context store or returns a default value when the key
33+
// is missing. Returns ErrInvalidKeyType error if the value is not castable to type T.
34+
func ContextGetOr[T any](c Context, key string, defaultValue T) (T, error) {
35+
typed, err := ContextGet[T](c, key)
36+
if err == ErrNonExistentKey {
37+
return defaultValue, nil
38+
}
39+
return typed, err
40+
}

context_generic_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package echo
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestContextGetOK(t *testing.T) {
10+
e := New()
11+
c := e.NewContext(nil, nil)
12+
13+
c.Set("key", int64(123))
14+
15+
v, err := ContextGet[int64](c, "key")
16+
assert.NoError(t, err)
17+
assert.Equal(t, int64(123), v)
18+
}
19+
20+
func TestContextGetNonExistentKey(t *testing.T) {
21+
e := New()
22+
c := e.NewContext(nil, nil)
23+
24+
c.Set("key", int64(123))
25+
26+
v, err := ContextGet[int64](c, "nope")
27+
assert.ErrorIs(t, err, ErrNonExistentKey)
28+
assert.Equal(t, int64(0), v)
29+
}
30+
31+
func TestContextGetInvalidCast(t *testing.T) {
32+
e := New()
33+
c := e.NewContext(nil, nil)
34+
35+
c.Set("key", int64(123))
36+
37+
v, err := ContextGet[bool](c, "key")
38+
assert.ErrorIs(t, err, ErrInvalidKeyType)
39+
assert.Equal(t, false, v)
40+
}
41+
42+
func TestContextGetOrOK(t *testing.T) {
43+
e := New()
44+
c := e.NewContext(nil, nil)
45+
46+
c.Set("key", int64(123))
47+
48+
v, err := ContextGetOr[int64](c, "key", 999)
49+
assert.NoError(t, err)
50+
assert.Equal(t, int64(123), v)
51+
}
52+
53+
func TestContextGetOrNonExistentKey(t *testing.T) {
54+
e := New()
55+
c := e.NewContext(nil, nil)
56+
57+
c.Set("key", int64(123))
58+
59+
v, err := ContextGetOr[int64](c, "nope", 999)
60+
assert.NoError(t, err)
61+
assert.Equal(t, int64(999), v)
62+
}
63+
64+
func TestContextGetOrInvalidCast(t *testing.T) {
65+
e := New()
66+
c := e.NewContext(nil, nil)
67+
68+
c.Set("key", int64(123))
69+
70+
v, err := ContextGetOr[float32](c, "key", float32(999))
71+
assert.ErrorIs(t, err, ErrInvalidKeyType)
72+
assert.Equal(t, float32(0), v)
73+
}

0 commit comments

Comments
 (0)