Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions runtime/marshal_urlencode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package runtime

import (
"fmt"
"io"
"net/url"

"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/protobuf/proto"
)

type UrlEncodedDecoder struct {
r io.Reader
}

func NewUrlEncodedDecoder(r io.Reader) Decoder {
return &UrlEncodedDecoder{r: r}
}

func (u *UrlEncodedDecoder) Decode(v interface{}) error {
msg, ok := v.(proto.Message)
if !ok {
return fmt.Errorf("not proto message")
}

formData, err := io.ReadAll(u.r)
if err != nil {
return err
}

values, err := url.ParseQuery(string(formData))
if err != nil {
return err
}

filter := &utilities.DoubleArray{}

err = PopulateQueryParameters(msg, values, filter)
if err != nil {
return err
}

return nil
}

type UrlEncodeMarshal struct {
Marshaler
}

// ContentType means the content type of the response
func (u *UrlEncodeMarshal) ContentType(_ interface{}) string {
return "application/json"
}

func (u *UrlEncodeMarshal) Marshal(v interface{}) ([]byte, error) {
return u.Marshaler.Marshal(v)
}

// NewDecoder indicates how to decode the request
func (u *UrlEncodeMarshal) NewDecoder(r io.Reader) Decoder {
return NewUrlEncodedDecoder(r)
}
150 changes: 150 additions & 0 deletions runtime/marshal_urlencode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package runtime

import (
"bytes"
"net/http"
"net/url"
"strings"
"testing"

"github.com/grpc-ecosystem/grpc-gateway/v2/runtime/internal/examplepb"
"google.golang.org/protobuf/proto"
)

func TestUrlEncodedDecoder_Decode(t *testing.T) {
tests := []struct {
name string
values url.Values
want proto.Message
wantErr bool
}{
{
name: "simple form fields",
values: url.Values{
"single_nested.name": {"test"},
"single_nested.amount": {"42"},
},
want: &examplepb.ABitOfEverything{
SingleNested: &examplepb.ABitOfEverything_Nested{
Name: "test",
Amount: 42,
},
},
wantErr: false,
},
{
name: "fields with special characters",
values: url.Values{
"single_nested.name": {"Hello World!"},
"single_nested.amount": {"123"},
},
want: &examplepb.ABitOfEverything{
SingleNested: &examplepb.ABitOfEverything_Nested{
Name: "Hello World!",
Amount: 123,
},
},
wantErr: false,
},
{
name: "empty input",
values: url.Values{},
want: &examplepb.ABitOfEverything{},
wantErr: false,
},
{
name: "repeated field",
values: url.Values{
"repeated_string_value": {"one", "two", "three"},
},
want: &examplepb.ABitOfEverything{
RepeatedStringValue: []string{"one", "two", "three"},
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader(tt.values.Encode()))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

decoder := NewUrlEncodedDecoder(req.Body)
msg := &examplepb.ABitOfEverything{}

err = decoder.Decode(msg)
if (err != nil) != tt.wantErr {
t.Errorf("UrlEncodedDecoder.Decode() error = %v, wantErr %v", err, tt.wantErr)
return
}

if !tt.wantErr && !proto.Equal(msg, tt.want) {
t.Errorf("UrlEncodedDecoder.Decode() = %v, want %v", msg, tt.want)
}
})
}
}

func TestUrlEncodedDecoder_DecodeNonProto(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader(""))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

decoder := NewUrlEncodedDecoder(req.Body)
var nonProto struct{}

err = decoder.Decode(&nonProto)
if err == nil {
t.Error("UrlEncodedDecoder.Decode() expected error for non-proto message")
}
}

func TestUrlEncodeMarshal_ContentType(t *testing.T) {
m := &UrlEncodeMarshal{}
if got := m.ContentType(nil); got != "application/json" {
t.Errorf("UrlEncodeMarshal.ContentType() = %v, want application/json", got)
}
}

func TestUrlEncodeMarshal_Marshal(t *testing.T) {
msg := &examplepb.ABitOfEverything{
SingleNested: &examplepb.ABitOfEverything_Nested{
Name: "test",
Amount: 42,
},
}

marshaler := &UrlEncodeMarshal{
Marshaler: &JSONPb{},
}

got, err := marshaler.Marshal(msg)
if err != nil {
t.Fatalf("UrlEncodeMarshal.Marshal() error = %v", err)
}

want := []byte(`{"single_nested":{"name":"test","amount":42}}`)
if !bytes.Equal(got, want) {
t.Errorf("UrlEncodeMarshal.Marshal() = %s, want %s", got, want)
}
}

func TestUrlEncodeMarshal_NewDecoder(t *testing.T) {
m := &UrlEncodeMarshal{}
req, err := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader(""))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

decoder := m.NewDecoder(req.Body)

if _, ok := decoder.(*UrlEncodedDecoder); !ok {
t.Error("UrlEncodeMarshal.NewDecoder() did not return *UrlEncodedDecoder")
}
}
30 changes: 20 additions & 10 deletions runtime/marshaler_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,30 @@ import (

// MIMEWildcard is the fallback MIME type used for requests which do not match
// a registered MIME type.
const MIMEWildcard = "*"
const (
MIMEWildcard = "*"
MIMEUrlEncoded = "application/x-www-form-urlencoded"
)

var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")

defaultMarshaler = &HTTPBodyMarshaler{
Marshaler: &JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
defaultJsonPbMarshaler = &JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
}

defaultMarshaler = &HTTPBodyMarshaler{
Marshaler: defaultJsonPbMarshaler,
}

urlEncodedMarshaler = &UrlEncodeMarshal{
Marshaler: defaultJsonPbMarshaler,
}
)

Expand Down Expand Up @@ -93,7 +102,8 @@ func (m marshalerRegistry) add(mime string, marshaler Marshaler) error {
func makeMarshalerMIMERegistry() marshalerRegistry {
return marshalerRegistry{
mimeMap: map[string]Marshaler{
MIMEWildcard: defaultMarshaler,
MIMEWildcard: defaultMarshaler,
MIMEUrlEncoded: urlEncodedMarshaler,
},
}
}
Expand Down
Loading