Skip to content

Commit dad9625

Browse files
committed
use RunAgentInput struct in our golang client
1 parent 5f6bba3 commit dad9625

File tree

3 files changed

+94
-51
lines changed

3 files changed

+94
-51
lines changed

sdks/community/go/pkg/client/sse/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"strings"
1212
"time"
1313

14+
"github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
1415
"github.com/sirupsen/logrus"
1516
)
1617

@@ -38,7 +39,7 @@ type Frame struct {
3839

3940
type StreamOptions struct {
4041
Context context.Context
41-
Payload interface{}
42+
Payload types.RunAgentInput
4243
Headers map[string]string
4344
}
4445

sdks/community/go/pkg/client/sse/client_stream_test.go

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,45 +11,54 @@ import (
1111
"testing"
1212
"time"
1313

14+
"github.com/ag-ui-protocol/ag-ui/sdks/community/go/pkg/core/types"
1415
"github.com/sirupsen/logrus"
1516
"github.com/stretchr/testify/assert"
1617
"github.com/stretchr/testify/require"
1718
)
1819

20+
// testPayload returns a simple RunAgentInput for testing
21+
func testPayload() types.RunAgentInput {
22+
return types.RunAgentInput{
23+
ThreadId: "test-thread",
24+
RunId: "test-run",
25+
}
26+
}
27+
1928
func TestStream(t *testing.T) {
2029
t.Run("successful stream", func(t *testing.T) {
2130
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2231
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
2332
assert.Equal(t, "text/event-stream", r.Header.Get("Accept"))
24-
33+
2534
w.Header().Set("Content-Type", "text/event-stream")
2635
w.WriteHeader(http.StatusOK)
27-
36+
2837
flusher, ok := w.(http.Flusher)
2938
require.True(t, ok)
30-
39+
3140
fmt.Fprintf(w, "data: first message\n\n")
3241
flusher.Flush()
33-
42+
3443
fmt.Fprintf(w, "data: second message\n\n")
3544
flusher.Flush()
36-
45+
3746
fmt.Fprintf(w, "data: {\"type\":\"json\",\"value\":123}\n\n")
3847
flusher.Flush()
3948
}))
4049
defer server.Close()
41-
50+
4251
client := NewClient(Config{
4352
Endpoint: server.URL,
4453
BufferSize: 10,
4554
})
46-
55+
4756
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
4857
defer cancel()
49-
58+
5059
frames, errors, err := client.Stream(StreamOptions{
5160
Context: ctx,
52-
Payload: map[string]string{"test": "data"},
61+
Payload: testPayload(),
5362
})
5463
require.NoError(t, err)
5564

@@ -106,10 +115,10 @@ func TestStream(t *testing.T) {
106115

107116
frames, _, err := client.Stream(StreamOptions{
108117
Context: ctx,
109-
Payload: struct{}{},
118+
Payload: testPayload(),
110119
})
111120
require.NoError(t, err)
112-
121+
113122
select {
114123
case frame := <-frames:
115124
assert.Equal(t, "line1\nline2\nline3", string(frame.Data))
@@ -170,13 +179,13 @@ func TestStream(t *testing.T) {
170179

171180
_, _, err := client.Stream(StreamOptions{
172181
Context: ctx,
173-
Payload: struct{}{},
182+
Payload: testPayload(),
174183
})
175184
require.NoError(t, err)
176185
})
177186
}
178187
})
179-
188+
180189
t.Run("custom headers", func(t *testing.T) {
181190
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
182191
assert.Equal(t, "custom-value", r.Header.Get("X-Custom-Header"))
@@ -195,15 +204,15 @@ func TestStream(t *testing.T) {
195204

196205
_, _, err := client.Stream(StreamOptions{
197206
Context: ctx,
198-
Payload: struct{}{},
207+
Payload: testPayload(),
199208
Headers: map[string]string{
200209
"X-Custom-Header": "custom-value",
201210
"X-Another-Header": "another-value",
202211
},
203212
})
204213
require.NoError(t, err)
205214
})
206-
215+
207216
t.Run("error responses", func(t *testing.T) {
208217
tests := []struct {
209218
name string
@@ -250,16 +259,16 @@ func TestStream(t *testing.T) {
250259
client := NewClient(Config{
251260
Endpoint: server.URL,
252261
})
253-
262+
254263
_, _, err := client.Stream(StreamOptions{
255-
Payload: struct{}{},
264+
Payload: testPayload(),
256265
})
257266
require.Error(t, err)
258267
assert.Contains(t, err.Error(), tt.expectedErr)
259268
})
260269
}
261270
})
262-
271+
263272
t.Run("context cancellation", func(t *testing.T) {
264273
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
265274
w.Header().Set("Content-Type", "text/event-stream")
@@ -283,13 +292,13 @@ func TestStream(t *testing.T) {
283292

284293
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
285294
defer cancel()
286-
295+
287296
frames, errors, err := client.Stream(StreamOptions{
288297
Context: ctx,
289-
Payload: struct{}{},
298+
Payload: testPayload(),
290299
})
291300
require.NoError(t, err)
292-
301+
293302
messageCount := 0
294303
for {
295304
select {
@@ -309,32 +318,17 @@ func TestStream(t *testing.T) {
309318
}
310319
})
311320

312-
t.Run("invalid payload marshaling", func(t *testing.T) {
313-
client := NewClient(Config{
314-
Endpoint: "http://localhost",
315-
})
316-
317-
// Create an unmarshalable payload
318-
invalidPayload := make(chan int)
319-
320-
_, _, err := client.Stream(StreamOptions{
321-
Payload: invalidPayload,
322-
})
323-
require.Error(t, err)
324-
assert.Contains(t, err.Error(), "failed to marshal payload")
325-
})
326-
327321
t.Run("invalid endpoint", func(t *testing.T) {
328322
client := NewClient(Config{
329323
Endpoint: "http://[::1]:namedport", // Invalid URL
330324
})
331-
325+
332326
_, _, err := client.Stream(StreamOptions{
333-
Payload: struct{}{},
327+
Payload: testPayload(),
334328
})
335329
require.Error(t, err)
336330
})
337-
331+
338332
t.Run("concurrent reads", func(t *testing.T) {
339333
messageCount := 50
340334
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -358,13 +352,13 @@ func TestStream(t *testing.T) {
358352

359353
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
360354
defer cancel()
361-
355+
362356
frames, _, err := client.Stream(StreamOptions{
363357
Context: ctx,
364-
Payload: struct{}{},
358+
Payload: testPayload(),
365359
})
366360
require.NoError(t, err)
367-
361+
368362
var wg sync.WaitGroup
369363
received := make(map[string]bool)
370364
mu := sync.Mutex{}
@@ -410,13 +404,13 @@ func TestStream(t *testing.T) {
410404

411405
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
412406
defer cancel()
413-
407+
414408
frames, errors, err := client.Stream(StreamOptions{
415409
Context: ctx,
416-
Payload: struct{}{},
410+
Payload: testPayload(),
417411
})
418412
require.NoError(t, err)
419-
413+
420414
// Should receive initial message
421415
select {
422416
case frame := <-frames:
@@ -463,13 +457,13 @@ func TestStream(t *testing.T) {
463457

464458
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
465459
defer cancel()
466-
460+
467461
frames, _, err := client.Stream(StreamOptions{
468462
Context: ctx,
469-
Payload: struct{}{},
463+
Payload: testPayload(),
470464
})
471465
require.NoError(t, err)
472-
466+
473467
// Consume all frames
474468
go func() {
475469
for range frames {
@@ -691,12 +685,12 @@ func BenchmarkStream(b *testing.B) {
691685

692686
frames, _, err := client.Stream(StreamOptions{
693687
Context: ctx,
694-
Payload: struct{}{},
688+
Payload: testPayload(),
695689
})
696690
if err != nil {
697691
b.Fatal(err)
698692
}
699-
693+
700694
count := 0
701695
for range frames {
702696
count++
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package types
2+
3+
// Context represents additional context provided to the agent
4+
type Context struct {
5+
Description string `json:"description"`
6+
Value string `json:"value"`
7+
}
8+
9+
// Tool represents a tool available to the agent
10+
type Tool struct {
11+
Name string `json:"name"`
12+
Description string `json:"description"`
13+
Parameters any `json:"parameters"` // JSON Schema for the tool parameters
14+
}
15+
16+
// RunAgentInput represents the input payload for running an agent
17+
type RunAgentInput struct {
18+
ThreadId string `json:"threadId"`
19+
RunId string `json:"runId"`
20+
State any `json:"state,omitempty"`
21+
Messages []Message `json:"messages"`
22+
Tools []Tool `json:"tools,omitempty"`
23+
Context []Context `json:"context,omitempty"`
24+
ForwardedProps any `json:"forwardedProps,omitempty"`
25+
}
26+
27+
// Message represents a message in the conversation
28+
type Message struct {
29+
ID string `json:"id"`
30+
Role string `json:"role"`
31+
Content *string `json:"content,omitempty"`
32+
Name *string `json:"name,omitempty"`
33+
ToolCalls []ToolCall `json:"toolCalls,omitempty"`
34+
ToolCallID *string `json:"toolCallId,omitempty"`
35+
}
36+
37+
// ToolCall represents a tool call within a message
38+
type ToolCall struct {
39+
ID string `json:"id"`
40+
Type string `json:"type"`
41+
Function Function `json:"function"`
42+
}
43+
44+
// Function represents a function call
45+
type Function struct {
46+
Name string `json:"name"`
47+
Arguments string `json:"arguments"`
48+
}

0 commit comments

Comments
 (0)