Skip to content

Commit 8db8c1a

Browse files
jacksontjachew22
authored andcommitted
Add test for custom marshaler and chunked responses
1 parent cde2f8f commit 8db8c1a

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

runtime/handler_test.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,103 @@ func TestForwardResponseStream(t *testing.T) {
102102
})
103103
}
104104
}
105+
106+
107+
// A custom marshaler implementation, that doesn't implement the delimited interface
108+
type CustomMarshaler struct {
109+
m *runtime.JSONPb
110+
}
111+
func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) }
112+
func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) }
113+
func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) }
114+
func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) }
115+
func (c *CustomMarshaler) ContentType() string { return c.m.ContentType() }
116+
117+
118+
func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
119+
type msg struct {
120+
pb proto.Message
121+
err error
122+
}
123+
tests := []struct {
124+
name string
125+
msgs []msg
126+
statusCode int
127+
}{{
128+
name: "encoding",
129+
msgs: []msg{
130+
{&pb.SimpleMessage{Id: "One"}, nil},
131+
{&pb.SimpleMessage{Id: "Two"}, nil},
132+
},
133+
statusCode: http.StatusOK,
134+
}, {
135+
name: "empty",
136+
statusCode: http.StatusOK,
137+
}, {
138+
name: "error",
139+
msgs: []msg{{nil, grpc.Errorf(codes.OutOfRange, "400")}},
140+
statusCode: http.StatusBadRequest,
141+
}, {
142+
name: "stream_error",
143+
msgs: []msg{
144+
{&pb.SimpleMessage{Id: "One"}, nil},
145+
{nil, grpc.Errorf(codes.OutOfRange, "400")},
146+
},
147+
statusCode: http.StatusOK,
148+
}}
149+
150+
newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
151+
var count int
152+
return func() (proto.Message, error) {
153+
if count == len(msgs) {
154+
return nil, io.EOF
155+
} else if count > len(msgs) {
156+
t.Errorf("recv() called %d times for %d messages", count, len(msgs))
157+
}
158+
count++
159+
msg := msgs[count-1]
160+
return msg.pb, msg.err
161+
}
162+
}
163+
ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
164+
marshaler := &CustomMarshaler{&runtime.JSONPb{}}
165+
for _, tt := range tests {
166+
t.Run(tt.name, func(t *testing.T) {
167+
recv := newTestRecv(t, tt.msgs)
168+
req := httptest.NewRequest("GET", "http://example.com/foo", nil)
169+
resp := httptest.NewRecorder()
170+
171+
runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
172+
173+
w := resp.Result()
174+
if w.StatusCode != tt.statusCode {
175+
t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
176+
}
177+
if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
178+
t.Errorf("ForwardResponseStream missing header chunked")
179+
}
180+
body, err := ioutil.ReadAll(w.Body)
181+
if err != nil {
182+
t.Errorf("Failed to read response body with %v", err)
183+
}
184+
w.Body.Close()
185+
186+
var want []byte
187+
for _, msg := range tt.msgs {
188+
if msg.err != nil {
189+
t.Skip("checking erorr encodings")
190+
}
191+
b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
192+
if err != nil {
193+
t.Errorf("marshaler.Marshal() failed %v", err)
194+
}
195+
want = append(want, b...)
196+
want = append(want, "\n"...)
197+
}
198+
199+
if string(body) != string(want) {
200+
t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
201+
}
202+
})
203+
}
204+
}

0 commit comments

Comments
 (0)