Skip to content

Commit c824b8d

Browse files
gdamorevishr
authored andcommitted
Add support for encoding.TextUnmarshaler in bind. (#1314)
1 parent 842fc87 commit c824b8d

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

bind.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package echo
22

33
import (
4+
"encoding"
45
"encoding/json"
56
"encoding/xml"
67
"errors"
@@ -21,6 +22,8 @@ type (
2122
DefaultBinder struct{}
2223

2324
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
25+
// Types that don't implement this, but do implement encoding.TextUnmarshaler
26+
// will use that interface instead.
2427
BindUnmarshaler interface {
2528
// UnmarshalParam decodes and assigns a value from an form or query param.
2629
UnmarshalParam(param string) error
@@ -211,12 +214,30 @@ func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
211214
return nil, false
212215
}
213216

217+
// textUnmarshaler attempts to unmarshal a reflect.Value into a TextUnmarshaler
218+
func textUnmarshaler(field reflect.Value) (encoding.TextUnmarshaler, bool) {
219+
ptr := reflect.New(field.Type())
220+
if ptr.CanInterface() {
221+
iface := ptr.Interface()
222+
if unmarshaler, ok := iface.(encoding.TextUnmarshaler); ok {
223+
return unmarshaler, ok
224+
}
225+
}
226+
return nil, false
227+
}
228+
214229
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
215230
if unmarshaler, ok := bindUnmarshaler(field); ok {
216231
err := unmarshaler.UnmarshalParam(value)
217232
field.Set(reflect.ValueOf(unmarshaler).Elem())
218233
return true, err
219234
}
235+
if unmarshaler, ok := textUnmarshaler(field); ok {
236+
err := unmarshaler.UnmarshalText([]byte(value))
237+
field.Set(reflect.ValueOf(unmarshaler).Elem())
238+
return true, err
239+
}
240+
220241
return false, nil
221242
}
222243

bind_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ type (
5050
PtrS *string
5151
cantSet string
5252
DoesntExist string
53+
GoT time.Time
54+
GoTptr *time.Time
5355
T Timestamp
5456
Tptr *Timestamp
5557
SA StringArray
@@ -116,6 +118,8 @@ var values = map[string][]string{
116118
"cantSet": {"test"},
117119
"T": {"2016-12-06T19:09:05+01:00"},
118120
"Tptr": {"2016-12-06T19:09:05+01:00"},
121+
"GoT": {"2016-12-06T19:09:05+01:00"},
122+
"GoTptr": {"2016-12-06T19:09:05+01:00"},
119123
"ST": {"bar"},
120124
}
121125

@@ -216,6 +220,28 @@ func TestBindUnmarshalParam(t *testing.T) {
216220
}
217221
}
218222

223+
func TestBindUnmarshalText(t *testing.T) {
224+
e := New()
225+
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
226+
rec := httptest.NewRecorder()
227+
c := e.NewContext(req, rec)
228+
result := struct {
229+
T time.Time `query:"ts"`
230+
TA []time.Time `query:"ta"`
231+
SA StringArray `query:"sa"`
232+
ST Struct
233+
}{}
234+
err := c.Bind(&result)
235+
ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)
236+
if assert.NoError(t, err) {
237+
// assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
238+
assert.Equal(t, ts, result.T)
239+
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
240+
assert.Equal(t, []time.Time{ts, ts}, result.TA)
241+
assert.Equal(t, Struct{"baz"}, result.ST)
242+
}
243+
}
244+
219245
func TestBindUnmarshalParamPtr(t *testing.T) {
220246
e := New()
221247
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
@@ -230,6 +256,20 @@ func TestBindUnmarshalParamPtr(t *testing.T) {
230256
}
231257
}
232258

259+
func TestBindUnmarshalTextPtr(t *testing.T) {
260+
e := New()
261+
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
262+
rec := httptest.NewRecorder()
263+
c := e.NewContext(req, rec)
264+
result := struct {
265+
Tptr *time.Time `query:"ts"`
266+
}{}
267+
err := c.Bind(&result)
268+
if assert.NoError(t, err) {
269+
assert.Equal(t, time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC), *result.Tptr)
270+
}
271+
}
272+
233273
func TestBindMultipartForm(t *testing.T) {
234274
body := new(bytes.Buffer)
235275
mw := multipart.NewWriter(body)

0 commit comments

Comments
 (0)