@@ -102,3 +102,103 @@ func TestForwardResponseStream(t *testing.T) {
102102 })
103103 }
104104}
105+
106+
107+ // A custom marshaler implementation, that doesn't implement the delimited interface
108+ type CustomMarshaler struct {
109+ m * runtime.JSONPb
110+ }
111+ func (c * CustomMarshaler ) Marshal (v interface {}) ([]byte , error ) { return c .m .Marshal (v ) }
112+ func (c * CustomMarshaler ) Unmarshal (data []byte , v interface {}) error { return c .m .Unmarshal (data , v ) }
113+ func (c * CustomMarshaler ) NewDecoder (r io.Reader ) runtime.Decoder { return c .m .NewDecoder (r ) }
114+ func (c * CustomMarshaler ) NewEncoder (w io.Writer ) runtime.Encoder { return c .m .NewEncoder (w ) }
115+ func (c * CustomMarshaler ) ContentType () string { return c .m .ContentType () }
116+
117+
118+ func TestForwardResponseStreamCustomMarshaler (t * testing.T ) {
119+ type msg struct {
120+ pb proto.Message
121+ err error
122+ }
123+ tests := []struct {
124+ name string
125+ msgs []msg
126+ statusCode int
127+ }{{
128+ name : "encoding" ,
129+ msgs : []msg {
130+ {& pb.SimpleMessage {Id : "One" }, nil },
131+ {& pb.SimpleMessage {Id : "Two" }, nil },
132+ },
133+ statusCode : http .StatusOK ,
134+ }, {
135+ name : "empty" ,
136+ statusCode : http .StatusOK ,
137+ }, {
138+ name : "error" ,
139+ msgs : []msg {{nil , grpc .Errorf (codes .OutOfRange , "400" )}},
140+ statusCode : http .StatusBadRequest ,
141+ }, {
142+ name : "stream_error" ,
143+ msgs : []msg {
144+ {& pb.SimpleMessage {Id : "One" }, nil },
145+ {nil , grpc .Errorf (codes .OutOfRange , "400" )},
146+ },
147+ statusCode : http .StatusOK ,
148+ }}
149+
150+ newTestRecv := func (t * testing.T , msgs []msg ) func () (proto.Message , error ) {
151+ var count int
152+ return func () (proto.Message , error ) {
153+ if count == len (msgs ) {
154+ return nil , io .EOF
155+ } else if count > len (msgs ) {
156+ t .Errorf ("recv() called %d times for %d messages" , count , len (msgs ))
157+ }
158+ count ++
159+ msg := msgs [count - 1 ]
160+ return msg .pb , msg .err
161+ }
162+ }
163+ ctx := runtime .NewServerMetadataContext (context .Background (), runtime.ServerMetadata {})
164+ marshaler := & CustomMarshaler {& runtime.JSONPb {}}
165+ for _ , tt := range tests {
166+ t .Run (tt .name , func (t * testing.T ) {
167+ recv := newTestRecv (t , tt .msgs )
168+ req := httptest .NewRequest ("GET" , "http://example.com/foo" , nil )
169+ resp := httptest .NewRecorder ()
170+
171+ runtime .ForwardResponseStream (ctx , runtime .NewServeMux (), marshaler , resp , req , recv )
172+
173+ w := resp .Result ()
174+ if w .StatusCode != tt .statusCode {
175+ t .Errorf ("StatusCode %d want %d" , w .StatusCode , tt .statusCode )
176+ }
177+ if h := w .Header .Get ("Transfer-Encoding" ); h != "chunked" {
178+ t .Errorf ("ForwardResponseStream missing header chunked" )
179+ }
180+ body , err := ioutil .ReadAll (w .Body )
181+ if err != nil {
182+ t .Errorf ("Failed to read response body with %v" , err )
183+ }
184+ w .Body .Close ()
185+
186+ var want []byte
187+ for _ , msg := range tt .msgs {
188+ if msg .err != nil {
189+ t .Skip ("checking erorr encodings" )
190+ }
191+ b , err := marshaler .Marshal (map [string ]proto.Message {"result" : msg .pb })
192+ if err != nil {
193+ t .Errorf ("marshaler.Marshal() failed %v" , err )
194+ }
195+ want = append (want , b ... )
196+ want = append (want , "\n " ... )
197+ }
198+
199+ if string (body ) != string (want ) {
200+ t .Errorf ("ForwardResponseStream() = \" %s\" want \" %s\" " , body , want )
201+ }
202+ })
203+ }
204+ }
0 commit comments