diff --git a/runtime/marshal_urlencode.go b/runtime/marshal_urlencode.go new file mode 100644 index 00000000000..8de188a4cc9 --- /dev/null +++ b/runtime/marshal_urlencode.go @@ -0,0 +1,62 @@ +package runtime + +import ( + "fmt" + "io" + "net/url" + + "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "google.golang.org/protobuf/proto" +) + +type UrlEncodedDecoder struct { + r io.Reader +} + +func NewUrlEncodedDecoder(r io.Reader) Decoder { + return &UrlEncodedDecoder{r: r} +} + +func (u *UrlEncodedDecoder) Decode(v interface{}) error { + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("not proto message") + } + + formData, err := io.ReadAll(u.r) + if err != nil { + return err + } + + values, err := url.ParseQuery(string(formData)) + if err != nil { + return err + } + + filter := &utilities.DoubleArray{} + + err = PopulateQueryParameters(msg, values, filter) + if err != nil { + return err + } + + return nil +} + +type UrlEncodeMarshal struct { + Marshaler +} + +// ContentType means the content type of the response +func (u *UrlEncodeMarshal) ContentType(_ interface{}) string { + return "application/json" +} + +func (u *UrlEncodeMarshal) Marshal(v interface{}) ([]byte, error) { + return u.Marshaler.Marshal(v) +} + +// NewDecoder indicates how to decode the request +func (u *UrlEncodeMarshal) NewDecoder(r io.Reader) Decoder { + return NewUrlEncodedDecoder(r) +} diff --git a/runtime/marshal_urlencode_test.go b/runtime/marshal_urlencode_test.go new file mode 100644 index 00000000000..4aaeeed6c36 --- /dev/null +++ b/runtime/marshal_urlencode_test.go @@ -0,0 +1,150 @@ +package runtime + +import ( + "bytes" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime/internal/examplepb" + "google.golang.org/protobuf/proto" +) + +func TestUrlEncodedDecoder_Decode(t *testing.T) { + tests := []struct { + name string + values url.Values + want proto.Message + wantErr bool + }{ + { + name: "simple form fields", + values: url.Values{ + "single_nested.name": {"test"}, + "single_nested.amount": {"42"}, + }, + want: &examplepb.ABitOfEverything{ + SingleNested: &examplepb.ABitOfEverything_Nested{ + Name: "test", + Amount: 42, + }, + }, + wantErr: false, + }, + { + name: "fields with special characters", + values: url.Values{ + "single_nested.name": {"Hello World!"}, + "single_nested.amount": {"123"}, + }, + want: &examplepb.ABitOfEverything{ + SingleNested: &examplepb.ABitOfEverything_Nested{ + Name: "Hello World!", + Amount: 123, + }, + }, + wantErr: false, + }, + { + name: "empty input", + values: url.Values{}, + want: &examplepb.ABitOfEverything{}, + wantErr: false, + }, + { + name: "repeated field", + values: url.Values{ + "repeated_string_value": {"one", "two", "three"}, + }, + want: &examplepb.ABitOfEverything{ + RepeatedStringValue: []string{"one", "two", "three"}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader(tt.values.Encode())) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + decoder := NewUrlEncodedDecoder(req.Body) + msg := &examplepb.ABitOfEverything{} + + err = decoder.Decode(msg) + if (err != nil) != tt.wantErr { + t.Errorf("UrlEncodedDecoder.Decode() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && !proto.Equal(msg, tt.want) { + t.Errorf("UrlEncodedDecoder.Decode() = %v, want %v", msg, tt.want) + } + }) + } +} + +func TestUrlEncodedDecoder_DecodeNonProto(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("")) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + decoder := NewUrlEncodedDecoder(req.Body) + var nonProto struct{} + + err = decoder.Decode(&nonProto) + if err == nil { + t.Error("UrlEncodedDecoder.Decode() expected error for non-proto message") + } +} + +func TestUrlEncodeMarshal_ContentType(t *testing.T) { + m := &UrlEncodeMarshal{} + if got := m.ContentType(nil); got != "application/json" { + t.Errorf("UrlEncodeMarshal.ContentType() = %v, want application/json", got) + } +} + +func TestUrlEncodeMarshal_Marshal(t *testing.T) { + msg := &examplepb.ABitOfEverything{ + SingleNested: &examplepb.ABitOfEverything_Nested{ + Name: "test", + Amount: 42, + }, + } + + marshaler := &UrlEncodeMarshal{ + Marshaler: &JSONPb{}, + } + + got, err := marshaler.Marshal(msg) + if err != nil { + t.Fatalf("UrlEncodeMarshal.Marshal() error = %v", err) + } + + want := []byte(`{"single_nested":{"name":"test","amount":42}}`) + if !bytes.Equal(got, want) { + t.Errorf("UrlEncodeMarshal.Marshal() = %s, want %s", got, want) + } +} + +func TestUrlEncodeMarshal_NewDecoder(t *testing.T) { + m := &UrlEncodeMarshal{} + req, err := http.NewRequest(http.MethodPost, "http://example.com", strings.NewReader("")) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + decoder := m.NewDecoder(req.Body) + + if _, ok := decoder.(*UrlEncodedDecoder); !ok { + t.Error("UrlEncodeMarshal.NewDecoder() did not return *UrlEncodedDecoder") + } +} diff --git a/runtime/marshaler_registry.go b/runtime/marshaler_registry.go index 07c28112c89..68b7bc54623 100644 --- a/runtime/marshaler_registry.go +++ b/runtime/marshaler_registry.go @@ -11,21 +11,30 @@ import ( // MIMEWildcard is the fallback MIME type used for requests which do not match // a registered MIME type. -const MIMEWildcard = "*" +const ( + MIMEWildcard = "*" + MIMEUrlEncoded = "application/x-www-form-urlencoded" +) var ( acceptHeader = http.CanonicalHeaderKey("Accept") contentTypeHeader = http.CanonicalHeaderKey("Content-Type") - defaultMarshaler = &HTTPBodyMarshaler{ - Marshaler: &JSONPb{ - MarshalOptions: protojson.MarshalOptions{ - EmitUnpopulated: true, - }, - UnmarshalOptions: protojson.UnmarshalOptions{ - DiscardUnknown: true, - }, + defaultJsonPbMarshaler = &JSONPb{ + MarshalOptions: protojson.MarshalOptions{ + EmitUnpopulated: true, }, + UnmarshalOptions: protojson.UnmarshalOptions{ + DiscardUnknown: true, + }, + } + + defaultMarshaler = &HTTPBodyMarshaler{ + Marshaler: defaultJsonPbMarshaler, + } + + urlEncodedMarshaler = &UrlEncodeMarshal{ + Marshaler: defaultJsonPbMarshaler, } ) @@ -93,7 +102,8 @@ func (m marshalerRegistry) add(mime string, marshaler Marshaler) error { func makeMarshalerMIMERegistry() marshalerRegistry { return marshalerRegistry{ mimeMap: map[string]Marshaler{ - MIMEWildcard: defaultMarshaler, + MIMEWildcard: defaultMarshaler, + MIMEUrlEncoded: urlEncodedMarshaler, }, } }