@@ -3,12 +3,16 @@ package actions
33import (
44 "context"
55 "fmt"
6+ "runtime"
67 "testing"
8+ "time"
79
810 config "github.com/conductorone/baton-sdk/pb/c1/config/v1"
911 v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
1012 "github.com/conductorone/baton-sdk/pkg/annotations"
1113 "github.com/stretchr/testify/require"
14+ "google.golang.org/grpc/codes"
15+ "google.golang.org/grpc/status"
1216 "google.golang.org/protobuf/types/known/structpb"
1317)
1418
@@ -47,6 +51,74 @@ func testActionHandler(ctx context.Context, args *structpb.Struct) (*structpb.St
4751 return & userStruct , nil , nil
4852}
4953
54+ func testAsyncActionHandler (ctx context.Context , args * structpb.Struct ) (* structpb.Struct , annotations.Annotations , error ) {
55+ _ , ok := args .Fields ["dn" ].GetKind ().(* structpb.Value_StringValue )
56+ if ! ok {
57+ return nil , nil , fmt .Errorf ("missing dn" )
58+ }
59+
60+ for i := 0 ; i < 12 ; i ++ {
61+ select {
62+ case <- ctx .Done ():
63+ return nil , nil , status .Error (codes .Canceled , "context canceled" )
64+ default :
65+ time .Sleep (100 * time .Millisecond )
66+ }
67+ }
68+
69+ var userStruct structpb.Struct = structpb.Struct {
70+ Fields : map [string ]* structpb.Value {
71+ "success" : {
72+ Kind : & structpb.Value_BoolValue {BoolValue : true },
73+ },
74+ },
75+ }
76+ return & userStruct , nil , nil
77+ }
78+
79+ var testInput = & structpb.Struct {
80+ Fields : map [string ]* structpb.Value {
81+ "dn" : {
82+ Kind : & structpb.Value_StringValue {StringValue : "test" },
83+ },
84+ },
85+ }
86+
87+ func testAsyncCancelActionHandler (ctx context.Context , args * structpb.Struct ) (* structpb.Struct , annotations.Annotations , error ) {
88+ _ , ok := args .Fields ["dn" ].GetKind ().(* structpb.Value_StringValue )
89+ if ! ok {
90+ return nil , nil , fmt .Errorf ("missing dn" )
91+ }
92+
93+ // Create a child context that we'll cancel after a short delay
94+ childCtx , cancel := context .WithCancel (ctx )
95+ defer cancel ()
96+
97+ // Start a goroutine to cancel after a short delay
98+ go func () {
99+ time .Sleep (100 * time .Millisecond )
100+ cancel ()
101+ }()
102+
103+ for i := 0 ; i < 12 ; i ++ {
104+ select {
105+ case <- childCtx .Done ():
106+ return nil , nil , status .Error (codes .Canceled , "context canceled" )
107+ default :
108+ time .Sleep (100 * time .Millisecond )
109+ }
110+ }
111+
112+ var userStruct structpb.Struct = structpb.Struct {
113+ Fields : map [string ]* structpb.Value {
114+ "success" : {
115+ Kind : & structpb.Value_BoolValue {BoolValue : true },
116+ },
117+ },
118+ }
119+ return & userStruct , nil , nil
120+ }
121+
50122func TestActionHandler (t * testing.T ) {
51123 ctx := context .Background ()
52124 m := NewActionManager (ctx )
@@ -64,24 +136,126 @@ func TestActionHandler(t *testing.T) {
64136 require .NoError (t , err )
65137 require .Equal (t , testActionSchema , schema )
66138
67- _ , status , returnArgs , _ , err := m .InvokeAction (ctx , "lock_account" , & structpb.Struct {
68- Fields : map [string ]* structpb.Value {
69- "dn" : {
70- Kind : & structpb.Value_StringValue {StringValue : "test" },
71- },
72- },
73- })
139+ _ , status , returnArgs , _ , err := m .InvokeAction (ctx , "lock_account" , testInput )
74140 require .NoError (t , err )
75141 require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_COMPLETE , status )
76142 require .NotNil (t , returnArgs )
77143 success , ok := returnArgs .Fields ["success" ].GetKind ().(* structpb.Value_BoolValue )
78144 require .True (t , ok )
79145 require .True (t , success .BoolValue )
80146
81- _ , status , _ , _ , err = m .InvokeAction (ctx , "lock_account" , & structpb.Struct {
147+ _ , status , rv , _ , err : = m .InvokeAction (ctx , "lock_account" , & structpb.Struct {
82148 Fields : map [string ]* structpb.Value {},
83149 })
84-
85- require .Error (t , err )
150+ expectedRv := & structpb.Struct {
151+ Fields : map [string ]* structpb.Value {
152+ "error" : {
153+ Kind : & structpb.Value_StringValue {StringValue : "missing dn" },
154+ },
155+ },
156+ }
157+ require .NoError (t , err )
86158 require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_FAILED , status )
159+ require .Equal (t , expectedRv , rv )
160+ }
161+
162+ func TestAsyncActionHandler (t * testing.T ) {
163+ ctx := context .Background ()
164+ m := NewActionManager (ctx )
165+ require .NotNil (t , m )
166+
167+ err := m .RegisterAction (ctx , "lock_account" , testActionSchema , testAsyncActionHandler )
168+ require .NoError (t , err )
169+
170+ schemas , _ , err := m .ListActionSchemas (ctx )
171+ require .NoError (t , err )
172+ require .Len (t , schemas , 1 )
173+ require .Equal (t , testActionSchema , schemas [0 ])
174+
175+ schema , _ , err := m .GetActionSchema (ctx , "lock_account" )
176+ require .NoError (t , err )
177+ require .Equal (t , testActionSchema , schema )
178+
179+ actionId , status , rv , _ , err := m .InvokeAction (ctx , "lock_account" , testInput )
180+ require .NoError (t , err )
181+ require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_RUNNING , status )
182+ require .Nil (t , rv )
183+
184+ status , name , _ , _ , err := m .GetActionStatus (ctx , actionId )
185+ require .NoError (t , err )
186+ require .Equal (t , "lock_account" , name )
187+ require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_RUNNING , status )
188+
189+ time .Sleep (1 * time .Second )
190+
191+ status , name , rv , _ , err = m .GetActionStatus (ctx , actionId )
192+ require .NoError (t , err )
193+ require .Equal (t , "lock_account" , name )
194+ require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_COMPLETE , status )
195+ require .NotNil (t , rv )
196+ success , ok := rv .Fields ["success" ].GetKind ().(* structpb.Value_BoolValue )
197+ require .True (t , ok )
198+ require .True (t , success .BoolValue )
199+ }
200+
201+ func TestActionHandlerGoroutineLeaks (t * testing.T ) {
202+ // Test case 1: Normal completion should not leak goroutines
203+ t .Run ("normal completion" , func (t * testing.T ) {
204+ ctx := context .Background ()
205+ m := NewActionManager (ctx )
206+ require .NotNil (t , m )
207+
208+ err := m .RegisterAction (ctx , "lock_account" , testActionSchema , testAsyncActionHandler )
209+ require .NoError (t , err )
210+
211+ // Get initial goroutine count
212+ initialCount := runtime .NumGoroutine ()
213+
214+ actionId , status , _ , _ , err := m .InvokeAction (ctx , "lock_account" , testInput )
215+ require .NoError (t , err )
216+ require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_RUNNING , status )
217+
218+ // Wait for completion
219+ time .Sleep (1 * time .Second )
220+
221+ // Check final status
222+ status , name , _ , _ , err := m .GetActionStatus (ctx , actionId )
223+ require .NoError (t , err )
224+ require .Equal (t , "lock_account" , name )
225+ require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_COMPLETE , status )
226+
227+ // Give a small grace period for goroutines to clean up
228+ time .Sleep (100 * time .Millisecond )
229+
230+ // Verify no goroutine leaks
231+ finalCount := runtime .NumGoroutine ()
232+ require .LessOrEqual (t , finalCount , initialCount + 1 , "goroutine leak detected after normal completion" )
233+ })
234+
235+ // Test case 2: Cancelled context should not leak goroutines
236+ t .Run ("context cancellation" , func (t * testing.T ) {
237+ ctx := context .Background ()
238+ m := NewActionManager (ctx )
239+ require .NotNil (t , m )
240+
241+ err := m .RegisterAction (ctx , "lock_account" , testActionSchema , testAsyncCancelActionHandler )
242+ require .NoError (t , err )
243+
244+ // Get initial goroutine count
245+ initialCount := runtime .NumGoroutine ()
246+
247+ _ , status , rv , _ , err := m .InvokeAction (ctx , "lock_account" , testInput )
248+ require .NoError (t , err )
249+ require .Equal (t , v2 .BatonActionStatus_BATON_ACTION_STATUS_FAILED , status )
250+
251+ errMsg := rv .Fields ["error" ].GetKind ().(* structpb.Value_StringValue ).StringValue
252+ require .Contains (t , errMsg , "context canceled" )
253+
254+ // Give a small grace period for goroutines to clean up
255+ time .Sleep (100 * time .Millisecond )
256+
257+ // Verify no goroutine leaks
258+ finalCount := runtime .NumGoroutine ()
259+ require .LessOrEqual (t , finalCount , initialCount + 1 , "goroutine leak detected after context cancellation" )
260+ })
87261}
0 commit comments