@@ -11,45 +11,54 @@ import (
1111 "testing"
1212 "time"
1313
14+ "github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
1415 "github.com/sirupsen/logrus"
1516 "github.com/stretchr/testify/assert"
1617 "github.com/stretchr/testify/require"
1718)
1819
20+ // testPayload returns a simple RunAgentInput for testing
21+ func testPayload () types.RunAgentInput {
22+ return types.RunAgentInput {
23+ ThreadId : "test-thread" ,
24+ RunId : "test-run" ,
25+ }
26+ }
27+
1928func TestStream (t * testing.T ) {
2029 t .Run ("successful stream" , func (t * testing.T ) {
2130 server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
2231 assert .Equal (t , "application/json" , r .Header .Get ("Content-Type" ))
2332 assert .Equal (t , "text/event-stream" , r .Header .Get ("Accept" ))
24-
33+
2534 w .Header ().Set ("Content-Type" , "text/event-stream" )
2635 w .WriteHeader (http .StatusOK )
27-
36+
2837 flusher , ok := w .(http.Flusher )
2938 require .True (t , ok )
30-
39+
3140 fmt .Fprintf (w , "data: first message\n \n " )
3241 flusher .Flush ()
33-
42+
3443 fmt .Fprintf (w , "data: second message\n \n " )
3544 flusher .Flush ()
36-
45+
3746 fmt .Fprintf (w , "data: {\" type\" :\" json\" ,\" value\" :123}\n \n " )
3847 flusher .Flush ()
3948 }))
4049 defer server .Close ()
41-
50+
4251 client := NewClient (Config {
4352 Endpoint : server .URL ,
4453 BufferSize : 10 ,
4554 })
46-
55+
4756 ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
4857 defer cancel ()
49-
58+
5059 frames , errors , err := client .Stream (StreamOptions {
5160 Context : ctx ,
52- Payload : map [ string ] string { "test" : "data" } ,
61+ Payload : testPayload () ,
5362 })
5463 require .NoError (t , err )
5564
@@ -106,10 +115,10 @@ func TestStream(t *testing.T) {
106115
107116 frames , _ , err := client .Stream (StreamOptions {
108117 Context : ctx ,
109- Payload : struct {}{} ,
118+ Payload : testPayload () ,
110119 })
111120 require .NoError (t , err )
112-
121+
113122 select {
114123 case frame := <- frames :
115124 assert .Equal (t , "line1\n line2\n line3" , string (frame .Data ))
@@ -170,13 +179,13 @@ func TestStream(t *testing.T) {
170179
171180 _ , _ , err := client .Stream (StreamOptions {
172181 Context : ctx ,
173- Payload : struct {}{} ,
182+ Payload : testPayload () ,
174183 })
175184 require .NoError (t , err )
176185 })
177186 }
178187 })
179-
188+
180189 t .Run ("custom headers" , func (t * testing.T ) {
181190 server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
182191 assert .Equal (t , "custom-value" , r .Header .Get ("X-Custom-Header" ))
@@ -195,15 +204,15 @@ func TestStream(t *testing.T) {
195204
196205 _ , _ , err := client .Stream (StreamOptions {
197206 Context : ctx ,
198- Payload : struct {}{} ,
207+ Payload : testPayload () ,
199208 Headers : map [string ]string {
200209 "X-Custom-Header" : "custom-value" ,
201210 "X-Another-Header" : "another-value" ,
202211 },
203212 })
204213 require .NoError (t , err )
205214 })
206-
215+
207216 t .Run ("error responses" , func (t * testing.T ) {
208217 tests := []struct {
209218 name string
@@ -250,16 +259,16 @@ func TestStream(t *testing.T) {
250259 client := NewClient (Config {
251260 Endpoint : server .URL ,
252261 })
253-
262+
254263 _ , _ , err := client .Stream (StreamOptions {
255- Payload : struct {}{} ,
264+ Payload : testPayload () ,
256265 })
257266 require .Error (t , err )
258267 assert .Contains (t , err .Error (), tt .expectedErr )
259268 })
260269 }
261270 })
262-
271+
263272 t .Run ("context cancellation" , func (t * testing.T ) {
264273 server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
265274 w .Header ().Set ("Content-Type" , "text/event-stream" )
@@ -283,13 +292,13 @@ func TestStream(t *testing.T) {
283292
284293 ctx , cancel := context .WithTimeout (context .Background (), 200 * time .Millisecond )
285294 defer cancel ()
286-
295+
287296 frames , errors , err := client .Stream (StreamOptions {
288297 Context : ctx ,
289- Payload : struct {}{} ,
298+ Payload : testPayload () ,
290299 })
291300 require .NoError (t , err )
292-
301+
293302 messageCount := 0
294303 for {
295304 select {
@@ -309,32 +318,17 @@ func TestStream(t *testing.T) {
309318 }
310319 })
311320
312- t .Run ("invalid payload marshaling" , func (t * testing.T ) {
313- client := NewClient (Config {
314- Endpoint : "http://localhost" ,
315- })
316-
317- // Create an unmarshalable payload
318- invalidPayload := make (chan int )
319-
320- _ , _ , err := client .Stream (StreamOptions {
321- Payload : invalidPayload ,
322- })
323- require .Error (t , err )
324- assert .Contains (t , err .Error (), "failed to marshal payload" )
325- })
326-
327321 t .Run ("invalid endpoint" , func (t * testing.T ) {
328322 client := NewClient (Config {
329323 Endpoint : "http://[::1]:namedport" , // Invalid URL
330324 })
331-
325+
332326 _ , _ , err := client .Stream (StreamOptions {
333- Payload : struct {}{} ,
327+ Payload : testPayload () ,
334328 })
335329 require .Error (t , err )
336330 })
337-
331+
338332 t .Run ("concurrent reads" , func (t * testing.T ) {
339333 messageCount := 50
340334 server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
@@ -358,13 +352,13 @@ func TestStream(t *testing.T) {
358352
359353 ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
360354 defer cancel ()
361-
355+
362356 frames , _ , err := client .Stream (StreamOptions {
363357 Context : ctx ,
364- Payload : struct {}{} ,
358+ Payload : testPayload () ,
365359 })
366360 require .NoError (t , err )
367-
361+
368362 var wg sync.WaitGroup
369363 received := make (map [string ]bool )
370364 mu := sync.Mutex {}
@@ -410,13 +404,13 @@ func TestStream(t *testing.T) {
410404
411405 ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
412406 defer cancel ()
413-
407+
414408 frames , errors , err := client .Stream (StreamOptions {
415409 Context : ctx ,
416- Payload : struct {}{} ,
410+ Payload : testPayload () ,
417411 })
418412 require .NoError (t , err )
419-
413+
420414 // Should receive initial message
421415 select {
422416 case frame := <- frames :
@@ -463,13 +457,13 @@ func TestStream(t *testing.T) {
463457
464458 ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
465459 defer cancel ()
466-
460+
467461 frames , _ , err := client .Stream (StreamOptions {
468462 Context : ctx ,
469- Payload : struct {}{} ,
463+ Payload : testPayload () ,
470464 })
471465 require .NoError (t , err )
472-
466+
473467 // Consume all frames
474468 go func () {
475469 for range frames {
@@ -691,12 +685,12 @@ func BenchmarkStream(b *testing.B) {
691685
692686 frames , _ , err := client .Stream (StreamOptions {
693687 Context : ctx ,
694- Payload : struct {}{} ,
688+ Payload : testPayload () ,
695689 })
696690 if err != nil {
697691 b .Fatal (err )
698692 }
699-
693+
700694 count := 0
701695 for range frames {
702696 count ++
0 commit comments