Skip to content

Commit ef04a8a

Browse files
efectngabyReneWerner87
authored
🐛 bug: Fix square bracket notation in Multipart FormData (#3235)
* 🐛 bug: add square bracket notation support to BindMultipart * Fix golangci-lint issues * Fixing undef variable * Fix more lint issues * test * update1 * improve coverage * fix linter * reduce code duplication * reduce code duplications in bindMultipart --------- Co-authored-by: Juan Calderon-Perez <[email protected]> Co-authored-by: René <[email protected]>
1 parent d0e767f commit ef04a8a

File tree

8 files changed

+186
-75
lines changed

8 files changed

+186
-75
lines changed

bind_test.go

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"errors"
99
"fmt"
10+
"mime/multipart"
1011
"net/http/httptest"
1112
"reflect"
1213
"testing"
@@ -886,7 +887,8 @@ func Test_Bind_Body(t *testing.T) {
886887
reqBody := []byte(`{"name":"john"}`)
887888

888889
type Demo struct {
889-
Name string `json:"name" xml:"name" form:"name" query:"name"`
890+
Name string `json:"name" xml:"name" form:"name" query:"name"`
891+
Names []string `json:"names" xml:"names" form:"names" query:"names"`
890892
}
891893

892894
// Helper function to test compressed bodies
@@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) {
996998
Data []Demo `query:"data"`
997999
}
9981000

1001+
t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
1002+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
1003+
c.Request().Reset()
1004+
1005+
buf := &bytes.Buffer{}
1006+
writer := multipart.NewWriter(buf)
1007+
require.NoError(t, writer.WriteField("data.0.name", "john"))
1008+
require.NoError(t, writer.WriteField("data.1.name", "doe"))
1009+
require.NoError(t, writer.Close())
1010+
1011+
c.Request().Header.SetContentType(writer.FormDataContentType())
1012+
c.Request().SetBody(buf.Bytes())
1013+
c.Request().Header.SetContentLength(len(c.Body()))
1014+
1015+
cq := new(CollectionQuery)
1016+
require.NoError(t, c.Bind().Body(cq))
1017+
require.Len(t, cq.Data, 2)
1018+
require.Equal(t, "john", cq.Data[0].Name)
1019+
require.Equal(t, "doe", cq.Data[1].Name)
1020+
})
1021+
1022+
t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
1023+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
1024+
c.Request().Reset()
1025+
1026+
buf := &bytes.Buffer{}
1027+
writer := multipart.NewWriter(buf)
1028+
require.NoError(t, writer.WriteField("data[0][name]", "john"))
1029+
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
1030+
require.NoError(t, writer.Close())
1031+
1032+
c.Request().Header.SetContentType(writer.FormDataContentType())
1033+
c.Request().SetBody(buf.Bytes())
1034+
c.Request().Header.SetContentLength(len(c.Body()))
1035+
1036+
cq := new(CollectionQuery)
1037+
require.NoError(t, c.Bind().Body(cq))
1038+
require.Len(t, cq.Data, 2)
1039+
require.Equal(t, "john", cq.Data[0].Name)
1040+
require.Equal(t, "doe", cq.Data[1].Name)
1041+
})
1042+
9991043
t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
10001044
c := app.AcquireCtx(&fasthttp.RequestCtx{})
10011045
c.Request().Reset()
@@ -1192,9 +1236,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
11921236
Name string `form:"name"`
11931237
}
11941238

1195-
body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
1239+
buf := &bytes.Buffer{}
1240+
writer := multipart.NewWriter(buf)
1241+
require.NoError(b, writer.WriteField("name", "john"))
1242+
require.NoError(b, writer.Close())
1243+
body := buf.Bytes()
1244+
1245+
c.Request().SetBody(body)
1246+
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
1247+
c.Request().Header.SetContentLength(len(body))
1248+
d := new(Demo)
1249+
1250+
b.ReportAllocs()
1251+
b.ResetTimer()
1252+
1253+
for n := 0; n < b.N; n++ {
1254+
err = c.Bind().Body(d)
1255+
}
1256+
1257+
require.NoError(b, err)
1258+
require.Equal(b, "john", d.Name)
1259+
}
1260+
1261+
// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
1262+
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
1263+
var err error
1264+
1265+
app := New()
1266+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
1267+
1268+
type Person struct {
1269+
Name string `form:"name"`
1270+
Age int `form:"age"`
1271+
}
1272+
1273+
type Demo struct {
1274+
Name string `form:"name"`
1275+
Persons []Person `form:"persons"`
1276+
}
1277+
1278+
buf := &bytes.Buffer{}
1279+
writer := multipart.NewWriter(buf)
1280+
require.NoError(b, writer.WriteField("name", "john"))
1281+
require.NoError(b, writer.WriteField("persons.0.name", "john"))
1282+
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
1283+
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
1284+
require.NoError(b, writer.WriteField("persons.1.age", "20"))
1285+
require.NoError(b, writer.Close())
1286+
body := buf.Bytes()
1287+
11961288
c.Request().SetBody(body)
1197-
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
1289+
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
11981290
c.Request().Header.SetContentLength(len(body))
11991291
d := new(Demo)
12001292

@@ -1204,8 +1296,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
12041296
for n := 0; n < b.N; n++ {
12051297
err = c.Bind().Body(d)
12061298
}
1299+
12071300
require.NoError(b, err)
12081301
require.Equal(b, "john", d.Name)
1302+
require.Equal(b, "john", d.Persons[0].Name)
1303+
require.Equal(b, 10, d.Persons[0].Age)
1304+
require.Equal(b, "doe", d.Persons[1].Name)
1305+
require.Equal(b, 20, d.Persons[1].Age)
12091306
}
12101307

12111308
// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4

binder/cookie.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package binder
22

33
import (
4-
"reflect"
5-
"strings"
6-
74
"github.com/gofiber/utils/v2"
85
"github.com/valyala/fasthttp"
96
)
@@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error {
3027

3128
k := utils.UnsafeString(key)
3229
v := utils.UnsafeString(val)
33-
34-
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
35-
values := strings.Split(v, ",")
36-
for i := 0; i < len(values); i++ {
37-
data[k] = append(data[k], values[i])
38-
}
39-
} else {
40-
data[k] = append(data[k], v)
41-
}
30+
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
4231
})
4332

4433
if err != nil {

binder/form.go

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package binder
22

33
import (
4-
"reflect"
5-
"strings"
6-
74
"github.com/gofiber/utils/v2"
85
"github.com/valyala/fasthttp"
96
)
@@ -37,19 +34,7 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {
3734

3835
k := utils.UnsafeString(key)
3936
v := utils.UnsafeString(val)
40-
41-
if strings.Contains(k, "[") {
42-
k, err = parseParamSquareBrackets(k)
43-
}
44-
45-
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
46-
values := strings.Split(v, ",")
47-
for i := 0; i < len(values); i++ {
48-
data[k] = append(data[k], values[i])
49-
}
50-
} else {
51-
data[k] = append(data[k], v)
52-
}
37+
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
5338
})
5439

5540
if err != nil {
@@ -61,12 +46,20 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {
6146

6247
// bindMultipart parses the request body and returns the result.
6348
func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
64-
data, err := req.MultipartForm()
49+
multipartForm, err := req.MultipartForm()
6550
if err != nil {
6651
return err
6752
}
6853

69-
return parse(b.Name(), out, data.Value)
54+
data := make(map[string][]string)
55+
for key, values := range multipartForm.Value {
56+
err = formatBindData(out, data, key, values, b.EnableSplitting, true)
57+
if err != nil {
58+
return err
59+
}
60+
}
61+
62+
return parse(b.Name(), out, data)
7063
}
7164

7265
// Reset resets the FormBinding binder.

binder/form_test.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
9393
}
9494
require.Equal(t, "form", b.Name())
9595

96+
type Post struct {
97+
Title string `form:"title"`
98+
}
99+
96100
type User struct {
97101
Name string `form:"name"`
98102
Names []string `form:"names"`
103+
Posts []Post `form:"posts"`
99104
Age int `form:"age"`
100105
}
101106
var user User
@@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
106111
mw := multipart.NewWriter(buf)
107112

108113
require.NoError(t, mw.WriteField("name", "john"))
109-
require.NoError(t, mw.WriteField("names", "john"))
114+
require.NoError(t, mw.WriteField("names", "john,eric"))
110115
require.NoError(t, mw.WriteField("names", "doe"))
111116
require.NoError(t, mw.WriteField("age", "42"))
117+
require.NoError(t, mw.WriteField("posts[0][title]", "post1"))
118+
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
119+
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))
120+
112121
require.NoError(t, mw.Close())
113122

114123
req.Header.SetContentType(mw.FormDataContentType())
@@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
125134
require.Equal(t, 42, user.Age)
126135
require.Contains(t, user.Names, "john")
127136
require.Contains(t, user.Names, "doe")
137+
require.Contains(t, user.Names, "eric")
138+
require.Len(t, user.Posts, 3)
139+
require.Equal(t, "post1", user.Posts[0].Title)
140+
require.Equal(t, "post2", user.Posts[1].Title)
141+
require.Equal(t, "post3", user.Posts[2].Title)
128142
}
129143

130144
func Benchmark_FormBinder_BindMultipart(b *testing.B) {

binder/header.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package binder
22

33
import (
4-
"reflect"
5-
"strings"
6-
74
"github.com/gofiber/utils/v2"
85
"github.com/valyala/fasthttp"
96
)
@@ -21,20 +18,21 @@ func (*HeaderBinding) Name() string {
2118
// Bind parses the request header and returns the result.
2219
func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error {
2320
data := make(map[string][]string)
21+
var err error
2422
req.Header.VisitAll(func(key, val []byte) {
23+
if err != nil {
24+
return
25+
}
26+
2527
k := utils.UnsafeString(key)
2628
v := utils.UnsafeString(val)
27-
28-
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
29-
values := strings.Split(v, ",")
30-
for i := 0; i < len(values); i++ {
31-
data[k] = append(data[k], values[i])
32-
}
33-
} else {
34-
data[k] = append(data[k], v)
35-
}
29+
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
3630
})
3731

32+
if err != nil {
33+
return err
34+
}
35+
3836
return parse(b.Name(), out, data)
3937
}
4038

binder/mapping.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error {
107107
func parseToMap(ptr any, data map[string][]string) error {
108108
elem := reflect.TypeOf(ptr).Elem()
109109

110-
switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types
110+
switch elem.Kind() {
111111
case reflect.Slice:
112112
newMap, ok := ptr.(map[string][]string)
113113
if !ok {
@@ -130,6 +130,8 @@ func parseToMap(ptr any, data map[string][]string) error {
130130
}
131131
newMap[k] = v[len(v)-1]
132132
}
133+
default:
134+
return nil // it's not necessary to check all types
133135
}
134136

135137
return nil
@@ -247,3 +249,37 @@ func FilterFlags(content string) string {
247249
}
248250
return content
249251
}
252+
253+
func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
254+
var err error
255+
if supportBracketNotation && strings.Contains(key, "[") {
256+
key, err = parseParamSquareBrackets(key)
257+
if err != nil {
258+
return err
259+
}
260+
}
261+
262+
switch v := any(value).(type) {
263+
case string:
264+
assignBindData(out, data, key, v, enableSplitting)
265+
case []string:
266+
for _, val := range v {
267+
assignBindData(out, data, key, val, enableSplitting)
268+
}
269+
default:
270+
return fmt.Errorf("unsupported value type: %T", value)
271+
}
272+
273+
return err
274+
}
275+
276+
func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay
277+
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) {
278+
values := strings.Split(value, ",")
279+
for i := 0; i < len(values); i++ {
280+
data[key] = append(data[key], values[i])
281+
}
282+
} else {
283+
data[key] = append(data[key], value)
284+
}
285+
}

binder/query.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package binder
22

33
import (
4-
"reflect"
5-
"strings"
6-
74
"github.com/gofiber/utils/v2"
85
"github.com/valyala/fasthttp"
96
)
@@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error {
3027

3128
k := utils.UnsafeString(key)
3229
v := utils.UnsafeString(val)
33-
34-
if strings.Contains(k, "[") {
35-
k, err = parseParamSquareBrackets(k)
36-
}
37-
38-
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
39-
values := strings.Split(v, ",")
40-
for i := 0; i < len(values); i++ {
41-
data[k] = append(data[k], values[i])
42-
}
43-
} else {
44-
data[k] = append(data[k], v)
45-
}
30+
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
4631
})
4732

4833
if err != nil {

0 commit comments

Comments
 (0)