11package acp
22
33import (
4+ "bytes"
45 "context"
56 "encoding/json"
7+ "errors"
68 "io"
9+ "log/slog"
710 "slices"
11+ "strings"
812 "sync"
913 "sync/atomic"
1014 "testing"
@@ -22,8 +26,12 @@ type clientFuncs struct {
2226 ReleaseTerminalFunc func (context.Context , ReleaseTerminalRequest ) (ReleaseTerminalResponse , error )
2327 TerminalOutputFunc func (context.Context , TerminalOutputRequest ) (TerminalOutputResponse , error )
2428 WaitForTerminalExitFunc func (context.Context , WaitForTerminalExitRequest ) (WaitForTerminalExitResponse , error )
29+
30+ HandleExtensionMethodFunc func (context.Context , string , json.RawMessage ) (any , error )
2531}
2632
33+ var _ ExtensionMethodHandler = (* clientFuncs )(nil )
34+
2735var _ Client = (* clientFuncs )(nil )
2836
2937func (c clientFuncs ) WriteTextFile (ctx context.Context , p WriteTextFileRequest ) (WriteTextFileResponse , error ) {
@@ -94,21 +102,30 @@ func (c *clientFuncs) WaitForTerminalExit(ctx context.Context, params WaitForTer
94102 return WaitForTerminalExitResponse {}, nil
95103}
96104
105+ func (c clientFuncs ) HandleExtensionMethod (ctx context.Context , method string , params json.RawMessage ) (any , error ) {
106+ if c .HandleExtensionMethodFunc != nil {
107+ return c .HandleExtensionMethodFunc (ctx , method , params )
108+ }
109+ return nil , NewMethodNotFound (method )
110+ }
111+
97112type agentFuncs struct {
98- InitializeFunc func (context.Context , InitializeRequest ) (InitializeResponse , error )
99- NewSessionFunc func (context.Context , NewSessionRequest ) (NewSessionResponse , error )
100- LoadSessionFunc func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error )
101- AuthenticateFunc func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error )
102- PromptFunc func (context.Context , PromptRequest ) (PromptResponse , error )
103- CancelFunc func (context.Context , CancelNotification ) error
104- SetSessionModeFunc func (ctx context.Context , params SetSessionModeRequest ) (SetSessionModeResponse , error )
105- SetSessionModelFunc func (ctx context.Context , params SetSessionModelRequest ) (SetSessionModelResponse , error )
113+ InitializeFunc func (context.Context , InitializeRequest ) (InitializeResponse , error )
114+ NewSessionFunc func (context.Context , NewSessionRequest ) (NewSessionResponse , error )
115+ LoadSessionFunc func (context.Context , LoadSessionRequest ) (LoadSessionResponse , error )
116+ AuthenticateFunc func (context.Context , AuthenticateRequest ) (AuthenticateResponse , error )
117+ PromptFunc func (context.Context , PromptRequest ) (PromptResponse , error )
118+ CancelFunc func (context.Context , CancelNotification ) error
119+ SetSessionModeFunc func (ctx context.Context , params SetSessionModeRequest ) (SetSessionModeResponse , error )
120+
121+ HandleExtensionMethodFunc func (context.Context , string , json.RawMessage ) (any , error )
106122}
107123
108124var (
109- _ Agent = (* agentFuncs )(nil )
110- _ AgentLoader = (* agentFuncs )(nil )
111- _ AgentExperimental = (* agentFuncs )(nil )
125+ _ Agent = (* agentFuncs )(nil )
126+ _ AgentLoader = (* agentFuncs )(nil )
127+ _ AgentExperimental = (* agentFuncs )(nil )
128+ _ ExtensionMethodHandler = (* agentFuncs )(nil )
112129)
113130
114131func (a agentFuncs ) Initialize (ctx context.Context , p InitializeRequest ) (InitializeResponse , error ) {
@@ -161,12 +178,11 @@ func (a agentFuncs) SetSessionMode(ctx context.Context, params SetSessionModeReq
161178 return SetSessionModeResponse {}, nil
162179}
163180
164- // SetSessionModel implements AgentExperimental.
165- func (a agentFuncs ) SetSessionModel (ctx context.Context , params SetSessionModelRequest ) (SetSessionModelResponse , error ) {
166- if a .SetSessionModelFunc != nil {
167- return a .SetSessionModelFunc (ctx , params )
181+ func (a agentFuncs ) HandleExtensionMethod (ctx context.Context , method string , params json.RawMessage ) (any , error ) {
182+ if a .HandleExtensionMethodFunc != nil {
183+ return a .HandleExtensionMethodFunc (ctx , method , params )
168184 }
169- return SetSessionModelResponse {}, nil
185+ return nil , NewMethodNotFound ( method )
170186}
171187
172188// Test bidirectional error handling similar to typescript/acp.test.ts
@@ -354,7 +370,7 @@ func TestConnectionHandlesMessageOrdering(t *testing.T) {
354370 }
355371 if _ , err := as .RequestPermission (context .Background (), RequestPermissionRequest {
356372 SessionId : "test-session" ,
357- ToolCall : RequestPermissionToolCall {
373+ ToolCall : ToolCallUpdate {
358374 Title : Ptr ("Execute command" ),
359375 Kind : ptr (ToolKindExecute ),
360376 Status : ptr (ToolCallStatusPending ),
@@ -887,7 +903,7 @@ func TestRequestHandlerCanMakeNestedRequest(t *testing.T) {
887903 PromptFunc : func (ctx context.Context , p PromptRequest ) (PromptResponse , error ) {
888904 _ , err := ag .RequestPermission (ctx , RequestPermissionRequest {
889905 SessionId : p .SessionId ,
890- ToolCall : RequestPermissionToolCall {
906+ ToolCall : ToolCallUpdate {
891907 ToolCallId : "call_1" ,
892908 Title : Ptr ("Test permission" ),
893909 },
@@ -921,3 +937,171 @@ func TestRequestHandlerCanMakeNestedRequest(t *testing.T) {
921937 t .Fatalf ("prompt failed: %v" , err )
922938 }
923939}
940+
941+ type extEchoParams struct {
942+ Msg string `json:"msg"`
943+ }
944+
945+ type extEchoResult struct {
946+ Msg string `json:"msg"`
947+ }
948+
949+ type agentNoExtensions struct {}
950+
951+ func (agentNoExtensions ) Authenticate (ctx context.Context , params AuthenticateRequest ) (AuthenticateResponse , error ) {
952+ return AuthenticateResponse {}, nil
953+ }
954+
955+ func (agentNoExtensions ) Initialize (ctx context.Context , params InitializeRequest ) (InitializeResponse , error ) {
956+ return InitializeResponse {}, nil
957+ }
958+
959+ func (agentNoExtensions ) Cancel (ctx context.Context , params CancelNotification ) error { return nil }
960+
961+ func (agentNoExtensions ) NewSession (ctx context.Context , params NewSessionRequest ) (NewSessionResponse , error ) {
962+ return NewSessionResponse {}, nil
963+ }
964+
965+ func (agentNoExtensions ) Prompt (ctx context.Context , params PromptRequest ) (PromptResponse , error ) {
966+ return PromptResponse {}, nil
967+ }
968+
969+ func (agentNoExtensions ) SetSessionMode (ctx context.Context , params SetSessionModeRequest ) (SetSessionModeResponse , error ) {
970+ return SetSessionModeResponse {}, nil
971+ }
972+
973+ func TestExtensionMethods_ClientToAgentRequest (t * testing.T ) {
974+ c2aR , c2aW := io .Pipe ()
975+ a2cR , a2cW := io .Pipe ()
976+
977+ method := "_vendor.test/echo"
978+
979+ ag := NewAgentSideConnection (agentFuncs {
980+ HandleExtensionMethodFunc : func (ctx context.Context , gotMethod string , params json.RawMessage ) (any , error ) {
981+ if gotMethod != method {
982+ return nil , NewInternalError (map [string ]any {"expected" : method , "got" : gotMethod })
983+ }
984+ var p extEchoParams
985+ if err := json .Unmarshal (params , & p ); err != nil {
986+ return nil , err
987+ }
988+ return extEchoResult {Msg : p .Msg }, nil
989+ },
990+ }, a2cW , c2aR )
991+
992+ _ = ag
993+
994+ c := NewClientSideConnection (& clientFuncs {}, c2aW , a2cR )
995+
996+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
997+ defer cancel ()
998+
999+ raw , err := c .CallExtension (ctx , method , extEchoParams {Msg : "hi" })
1000+ if err != nil {
1001+ t .Fatalf ("CallExtension: %v" , err )
1002+ }
1003+ var resp extEchoResult
1004+ if err := json .Unmarshal (raw , & resp ); err != nil {
1005+ t .Fatalf ("unmarshal response: %v" , err )
1006+ }
1007+ if resp .Msg != "hi" {
1008+ t .Fatalf ("unexpected response: %#v" , resp )
1009+ }
1010+ }
1011+
1012+ func TestExtensionMethods_UnknownRequest_ReturnsMethodNotFound (t * testing.T ) {
1013+ c2aR , c2aW := io .Pipe ()
1014+ a2cR , a2cW := io .Pipe ()
1015+
1016+ NewAgentSideConnection (agentNoExtensions {}, a2cW , c2aR )
1017+ c := NewClientSideConnection (& clientFuncs {}, c2aW , a2cR )
1018+
1019+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
1020+ defer cancel ()
1021+
1022+ _ , err := c .CallExtension (ctx , "_vendor.test/missing" , extEchoParams {Msg : "hi" })
1023+ if err == nil {
1024+ t .Fatalf ("expected error" )
1025+ }
1026+ var re * RequestError
1027+ if ! errors .As (err , & re ) {
1028+ t .Fatalf ("expected *RequestError, got %T: %v" , err , err )
1029+ }
1030+ if re .Code != - 32601 {
1031+ t .Fatalf ("expected -32601 method not found, got %d" , re .Code )
1032+ }
1033+ }
1034+
1035+ func TestExtensionMethods_UnknownNotification_DoesNotLog (t * testing.T ) {
1036+ c2aR , c2aW := io .Pipe ()
1037+ a2cR , a2cW := io .Pipe ()
1038+
1039+ done := make (chan struct {})
1040+
1041+ ag := NewAgentSideConnection (agentFuncs {
1042+ HandleExtensionMethodFunc : func (ctx context.Context , method string , params json.RawMessage ) (any , error ) {
1043+ close (done )
1044+ return nil , NewMethodNotFound (method )
1045+ },
1046+ }, a2cW , c2aR )
1047+
1048+ var logBuf bytes.Buffer
1049+ ag .SetLogger (slog .New (slog .NewTextHandler (& logBuf , & slog.HandlerOptions {Level : slog .LevelDebug })))
1050+
1051+ c := NewClientSideConnection (& clientFuncs {}, c2aW , a2cR )
1052+
1053+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
1054+ defer cancel ()
1055+
1056+ if err := c .NotifyExtension (ctx , "_vendor.test/notify" , map [string ]any {"hello" : "world" }); err != nil {
1057+ t .Fatalf ("NotifyExtension: %v" , err )
1058+ }
1059+
1060+ select {
1061+ case <- done :
1062+ // ok
1063+ case <- ctx .Done ():
1064+ t .Fatalf ("timeout waiting for notification handler" )
1065+ }
1066+
1067+ if strings .Contains (logBuf .String (), "failed to handle notification" ) {
1068+ t .Fatalf ("unexpected notification error log: %s" , logBuf .String ())
1069+ }
1070+ }
1071+
1072+ func TestExtensionMethods_AgentToClientRequest (t * testing.T ) {
1073+ c2aR , c2aW := io .Pipe ()
1074+ a2cR , a2cW := io .Pipe ()
1075+
1076+ method := "_vendor.test/echo"
1077+
1078+ _ = NewClientSideConnection (& clientFuncs {
1079+ HandleExtensionMethodFunc : func (ctx context.Context , gotMethod string , params json.RawMessage ) (any , error ) {
1080+ if gotMethod != method {
1081+ return nil , NewInternalError (map [string ]any {"expected" : method , "got" : gotMethod })
1082+ }
1083+ var p extEchoParams
1084+ if err := json .Unmarshal (params , & p ); err != nil {
1085+ return nil , err
1086+ }
1087+ return extEchoResult {Msg : p .Msg }, nil
1088+ },
1089+ }, c2aW , a2cR )
1090+
1091+ ag := NewAgentSideConnection (agentFuncs {}, a2cW , c2aR )
1092+
1093+ ctx , cancel := context .WithTimeout (context .Background (), 1 * time .Second )
1094+ defer cancel ()
1095+
1096+ raw , err := ag .CallExtension (ctx , method , extEchoParams {Msg : "hi" })
1097+ if err != nil {
1098+ t .Fatalf ("CallExtension: %v" , err )
1099+ }
1100+ var resp extEchoResult
1101+ if err := json .Unmarshal (raw , & resp ); err != nil {
1102+ t .Fatalf ("unmarshal response: %v" , err )
1103+ }
1104+ if resp .Msg != "hi" {
1105+ t .Fatalf ("unexpected response: %#v" , resp )
1106+ }
1107+ }
0 commit comments