Skip to content

Commit 5c30db8

Browse files
authored
If an action handler errors, set the error in the return value. (#314)
An error returned from InvokeAction or GetActionStatus now indicates an internal error, not a failure of the action handler itself.
1 parent c830e16 commit 5c30db8

File tree

6 files changed

+337
-123
lines changed

6 files changed

+337
-123
lines changed

pb/c1/connector/v2/action.pb.go

Lines changed: 109 additions & 97 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pb/c1/connector/v2/action.pb.validate.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/actions/actions.go

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,28 @@ func (oa *OutstandingAction) SetStatus(ctx context.Context, status v2.BatonActio
5858
oa.Status = status
5959
}
6060

61+
// setError records err on the action while holding oa.Mutex. It creates
// oa.Rv and its Fields map if they are nil, stores err.Error() as a string
// value under the "error" key, and saves err in oa.Err. The context
// argument is ignored. It does not touch oa.Status. err must be non-nil,
// because err.Error() is called unconditionally.
func (oa *OutstandingAction) setError(_ context.Context, err error) {
62+
oa.Mutex.Lock()
63+
defer oa.Mutex.Unlock()
64+
// Create the return struct and its field map on demand so the write below
// never hits a nil struct or a nil map.
if oa.Rv == nil {
65+
oa.Rv = &structpb.Struct{}
66+
}
67+
if oa.Rv.Fields == nil {
68+
oa.Rv.Fields = make(map[string]*structpb.Value)
69+
}
70+
// Any "error" field already present in oa.Rv is overwritten.
oa.Rv.Fields["error"] = &structpb.Value{
71+
Kind: &structpb.Value_StringValue{
72+
StringValue: err.Error(),
73+
},
74+
}
75+
oa.Err = err
76+
}
77+
78+
// SetError records err on the action via setError and then marks the action
// as failed by calling SetStatus with BATON_ACTION_STATUS_FAILED.
func (oa *OutstandingAction) SetError(ctx context.Context, err error) {
79+
// The error is stored first; the status is switched to FAILED afterwards.
oa.setError(ctx, err)
80+
oa.SetStatus(ctx, v2.BatonActionStatus_BATON_ACTION_STATUS_FAILED)
81+
}
82+
6183
const maxOldActions = 1000
6284

6385
type ActionManager struct {
@@ -185,14 +207,15 @@ func (a *ActionManager) GetActionSchema(ctx context.Context, name string) (*v2.B
185207
return schema, nil, nil
186208
}
187209

188-
func (a *ActionManager) GetActionStatus(ctx context.Context, actionId string) (v2.BatonActionStatus, *structpb.Struct, annotations.Annotations, error) {
210+
func (a *ActionManager) GetActionStatus(ctx context.Context, actionId string) (v2.BatonActionStatus, string, *structpb.Struct, annotations.Annotations, error) {
189211
oa := a.actions[actionId]
190212
if oa == nil {
191-
return v2.BatonActionStatus_BATON_ACTION_STATUS_UNKNOWN, nil, nil, status.Error(codes.NotFound, fmt.Sprintf("action id %s not found", actionId))
213+
return v2.BatonActionStatus_BATON_ACTION_STATUS_UNKNOWN, "", nil, nil, status.Error(codes.NotFound, fmt.Sprintf("action id %s not found", actionId))
192214
}
193215

194216
// Don't return oa.Err here because error is for GetActionStatus, not the action itself.
195-
return oa.Status, oa.Rv, oa.Annos, nil
217+
// If the handler failed, oa.Rv carries its message under the "error" key.
218+
return oa.Status, oa.Name, oa.Rv, oa.Annos, nil
196219
}
197220

198221
func (a *ActionManager) InvokeAction(ctx context.Context, name string, args *structpb.Struct) (string, v2.BatonActionStatus, *structpb.Struct, annotations.Annotations, error) {
@@ -206,28 +229,29 @@ func (a *ActionManager) InvokeAction(ctx context.Context, name string, args *str
206229
done := make(chan struct{})
207230

208231
// If handler exits within a second, return result.
209-
// If handler takes longer than 10 seconds, return status pending.
232+
// If handler takes longer than 1 second, return the action's current (still running) status.
210233
// If handler takes longer than an hour, return status failed.
211234
go func() {
212235
oa.SetStatus(ctx, v2.BatonActionStatus_BATON_ACTION_STATUS_RUNNING)
213236
handlerCtx, cancel := context.WithTimeoutCause(ctx, 1*time.Hour, errors.New("action handler timed out"))
214237
defer cancel()
215-
oa.Rv, oa.Annos, oa.Err = handler(handlerCtx, args)
216-
if oa.Err == nil {
238+
var oaErr error
239+
oa.Rv, oa.Annos, oaErr = handler(handlerCtx, args)
240+
if oaErr == nil {
217241
oa.SetStatus(ctx, v2.BatonActionStatus_BATON_ACTION_STATUS_COMPLETE)
218242
} else {
219-
oa.SetStatus(ctx, v2.BatonActionStatus_BATON_ACTION_STATUS_FAILED)
243+
oa.SetError(ctx, oaErr)
220244
}
221245
done <- struct{}{}
222246
}()
223247

224248
select {
225249
case <-done:
226-
return oa.Id, oa.Status, oa.Rv, oa.Annos, oa.Err
227-
case <-time.After(10 * time.Second):
228-
return oa.Id, oa.Status, oa.Rv, oa.Annos, oa.Err
250+
return oa.Id, oa.Status, oa.Rv, oa.Annos, nil
251+
case <-time.After(1 * time.Second):
252+
return oa.Id, oa.Status, oa.Rv, oa.Annos, nil
229253
case <-ctx.Done():
230-
oa.SetStatus(ctx, v2.BatonActionStatus_BATON_ACTION_STATUS_FAILED)
231-
return oa.Id, oa.Status, nil, nil, ctx.Err()
254+
oa.SetError(ctx, ctx.Err())
255+
return oa.Id, oa.Status, oa.Rv, oa.Annos, ctx.Err()
232256
}
233257
}

pkg/actions/actions_test.go

Lines changed: 184 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@ package actions
33
import (
44
"context"
55
"fmt"
6+
"runtime"
67
"testing"
8+
"time"
79

810
config "github.com/conductorone/baton-sdk/pb/c1/config/v1"
911
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
1012
"github.com/conductorone/baton-sdk/pkg/annotations"
1113
"github.com/stretchr/testify/require"
14+
"google.golang.org/grpc/codes"
15+
"google.golang.org/grpc/status"
1216
"google.golang.org/protobuf/types/known/structpb"
1317
)
1418

@@ -47,6 +51,74 @@ func testActionHandler(ctx context.Context, args *structpb.Struct) (*structpb.St
4751
return &userStruct, nil, nil
4852
}
4953

54+
// testAsyncActionHandler is a deliberately slow test action handler.
// It returns a "missing dn" error unless args has a string-valued "dn"
// field. Otherwise it loops up to 12 times, sleeping 100ms per iteration
// (about 1.2s in total) and checking ctx before each sleep. If ctx is done
// it returns a codes.Canceled status error; if the loop finishes it returns
// a struct whose "success" field is true. args must be non-nil because
// args.Fields is read directly.
func testAsyncActionHandler(ctx context.Context, args *structpb.Struct) (*structpb.Struct, annotations.Annotations, error) {
55+
_, ok := args.Fields["dn"].GetKind().(*structpb.Value_StringValue)
56+
if !ok {
57+
return nil, nil, fmt.Errorf("missing dn")
58+
}
59+
60+
// Simulate long-running work in small steps so cancellation is noticed
// between steps rather than only at the end.
for i := 0; i < 12; i++ {
61+
select {
62+
case <-ctx.Done():
63+
return nil, nil, status.Error(codes.Canceled, "context canceled")
64+
default:
65+
time.Sleep(100 * time.Millisecond)
66+
}
67+
}
68+
69+
var userStruct structpb.Struct = structpb.Struct{
70+
Fields: map[string]*structpb.Value{
71+
"success": {
72+
Kind: &structpb.Value_BoolValue{BoolValue: true},
73+
},
74+
},
75+
}
76+
return &userStruct, nil, nil
77+
}
78+
79+
// testInput is the shared action argument used by the tests in this file:
// a struct with a single "dn" field holding the string "test".
var testInput = &structpb.Struct{
80+
Fields: map[string]*structpb.Value{
81+
"dn": {
82+
Kind: &structpb.Value_StringValue{StringValue: "test"},
83+
},
84+
},
85+
}
86+
87+
// testAsyncCancelActionHandler is a test action handler that cancels itself.
// It returns a "missing dn" error unless args has a string-valued "dn"
// field. Otherwise it derives a child context from ctx, starts a goroutine
// that cancels the child after 100ms, and polls the child context in a loop
// of up to 12 iterations with a 100ms sleep each. Once the loop sees the
// child context done it returns a codes.Canceled status error; if the loop
// finishes without seeing that, it returns a struct whose "success" field
// is true. args must be non-nil because args.Fields is read directly.
func testAsyncCancelActionHandler(ctx context.Context, args *structpb.Struct) (*structpb.Struct, annotations.Annotations, error) {
88+
_, ok := args.Fields["dn"].GetKind().(*structpb.Value_StringValue)
89+
if !ok {
90+
return nil, nil, fmt.Errorf("missing dn")
91+
}
92+
93+
// Create a child context that we'll cancel after a short delay
94+
childCtx, cancel := context.WithCancel(ctx)
95+
defer cancel()
96+
97+
// Start a goroutine to cancel after a short delay
98+
go func() {
99+
time.Sleep(100 * time.Millisecond)
100+
cancel()
101+
}()
102+
103+
// The first check runs immediately, before the delayed cancel has fired,
// so at least one 100ms sleep happens before cancellation can be seen.
for i := 0; i < 12; i++ {
104+
select {
105+
case <-childCtx.Done():
106+
return nil, nil, status.Error(codes.Canceled, "context canceled")
107+
default:
108+
time.Sleep(100 * time.Millisecond)
109+
}
110+
}
111+
112+
var userStruct structpb.Struct = structpb.Struct{
113+
Fields: map[string]*structpb.Value{
114+
"success": {
115+
Kind: &structpb.Value_BoolValue{BoolValue: true},
116+
},
117+
},
118+
}
119+
return &userStruct, nil, nil
120+
}
121+
50122
func TestActionHandler(t *testing.T) {
51123
ctx := context.Background()
52124
m := NewActionManager(ctx)
@@ -64,24 +136,126 @@ func TestActionHandler(t *testing.T) {
64136
require.NoError(t, err)
65137
require.Equal(t, testActionSchema, schema)
66138

67-
_, status, returnArgs, _, err := m.InvokeAction(ctx, "lock_account", &structpb.Struct{
68-
Fields: map[string]*structpb.Value{
69-
"dn": {
70-
Kind: &structpb.Value_StringValue{StringValue: "test"},
71-
},
72-
},
73-
})
139+
_, status, returnArgs, _, err := m.InvokeAction(ctx, "lock_account", testInput)
74140
require.NoError(t, err)
75141
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_COMPLETE, status)
76142
require.NotNil(t, returnArgs)
77143
success, ok := returnArgs.Fields["success"].GetKind().(*structpb.Value_BoolValue)
78144
require.True(t, ok)
79145
require.True(t, success.BoolValue)
80146

81-
_, status, _, _, err = m.InvokeAction(ctx, "lock_account", &structpb.Struct{
147+
_, status, rv, _, err := m.InvokeAction(ctx, "lock_account", &structpb.Struct{
82148
Fields: map[string]*structpb.Value{},
83149
})
84-
85-
require.Error(t, err)
150+
expectedRv := &structpb.Struct{
151+
Fields: map[string]*structpb.Value{
152+
"error": {
153+
Kind: &structpb.Value_StringValue{StringValue: "missing dn"},
154+
},
155+
},
156+
}
157+
require.NoError(t, err)
86158
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_FAILED, status)
159+
require.Equal(t, expectedRv, rv)
160+
}
161+
162+
// TestAsyncActionHandler covers an action whose handler outlasts the
// InvokeAction call. It registers testAsyncActionHandler, checks the schema
// listing and lookup, then expects InvokeAction to return RUNNING with a nil
// result. After a one-second sleep it expects GetActionStatus to report
// COMPLETE, the action name "lock_account" and a "success" field set to true.
func TestAsyncActionHandler(t *testing.T) {
163+
ctx := context.Background()
164+
m := NewActionManager(ctx)
165+
require.NotNil(t, m)
166+
167+
err := m.RegisterAction(ctx, "lock_account", testActionSchema, testAsyncActionHandler)
168+
require.NoError(t, err)
169+
170+
schemas, _, err := m.ListActionSchemas(ctx)
171+
require.NoError(t, err)
172+
require.Len(t, schemas, 1)
173+
require.Equal(t, testActionSchema, schemas[0])
174+
175+
schema, _, err := m.GetActionSchema(ctx, "lock_account")
176+
require.NoError(t, err)
177+
require.Equal(t, testActionSchema, schema)
178+
179+
// The handler needs roughly 1.2s (12 x 100ms), so the action is expected
// to still be running, with no result yet, when InvokeAction returns.
actionId, status, rv, _, err := m.InvokeAction(ctx, "lock_account", testInput)
180+
require.NoError(t, err)
181+
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_RUNNING, status)
182+
require.Nil(t, rv)
183+
184+
// GetActionStatus reports the action's name along with its status.
status, name, _, _, err := m.GetActionStatus(ctx, actionId)
185+
require.NoError(t, err)
186+
require.Equal(t, "lock_account", name)
187+
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_RUNNING, status)
188+
189+
// Give the handler time to finish before polling again.
time.Sleep(1 * time.Second)
190+
191+
status, name, rv, _, err = m.GetActionStatus(ctx, actionId)
192+
require.NoError(t, err)
193+
require.Equal(t, "lock_account", name)
194+
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_COMPLETE, status)
195+
require.NotNil(t, rv)
196+
success, ok := rv.Fields["success"].GetKind().(*structpb.Value_BoolValue)
197+
require.True(t, ok)
198+
require.True(t, success.BoolValue)
199+
}
200+
201+
// TestActionHandlerGoroutineLeaks compares runtime.NumGoroutine before and
// after invoking an action, in two subtests: a slow handler that completes
// normally, and a handler that cancels its own context. Each subtest allows
// the final count to exceed the initial count by at most one.
//
// NOTE(review): the +1 allowance means a single lingering goroutine goes
// undetected. InvokeAction's worker sends on an unbuffered done channel, and
// after the 1-second timeout path nothing receives from it, so that worker
// may be exactly the goroutine being tolerated here. Confirm.
func TestActionHandlerGoroutineLeaks(t *testing.T) {
202+
// Test case 1: Normal completion should not leak goroutines
203+
t.Run("normal completion", func(t *testing.T) {
204+
ctx := context.Background()
205+
m := NewActionManager(ctx)
206+
require.NotNil(t, m)
207+
208+
err := m.RegisterAction(ctx, "lock_account", testActionSchema, testAsyncActionHandler)
209+
require.NoError(t, err)
210+
211+
// Get initial goroutine count
212+
initialCount := runtime.NumGoroutine()
213+
214+
actionId, status, _, _, err := m.InvokeAction(ctx, "lock_account", testInput)
215+
require.NoError(t, err)
216+
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_RUNNING, status)
217+
218+
// Wait for completion
219+
time.Sleep(1 * time.Second)
220+
221+
// Check final status
222+
status, name, _, _, err := m.GetActionStatus(ctx, actionId)
223+
require.NoError(t, err)
224+
require.Equal(t, "lock_account", name)
225+
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_COMPLETE, status)
226+
227+
// Give a small grace period for goroutines to clean up
228+
time.Sleep(100 * time.Millisecond)
229+
230+
// Verify no goroutine leaks
231+
finalCount := runtime.NumGoroutine()
232+
require.LessOrEqual(t, finalCount, initialCount+1, "goroutine leak detected after normal completion")
233+
})
234+
235+
// Test case 2: Cancelled context should not leak goroutines
236+
t.Run("context cancellation", func(t *testing.T) {
237+
ctx := context.Background()
238+
m := NewActionManager(ctx)
239+
require.NotNil(t, m)
240+
241+
err := m.RegisterAction(ctx, "lock_account", testActionSchema, testAsyncCancelActionHandler)
242+
require.NoError(t, err)
243+
244+
// Get initial goroutine count
245+
initialCount := runtime.NumGoroutine()
246+
247+
// The handler's failure is expected in the status and the result struct,
// not in InvokeAction's own error return.
_, status, rv, _, err := m.InvokeAction(ctx, "lock_account", testInput)
248+
require.NoError(t, err)
249+
require.Equal(t, v2.BatonActionStatus_BATON_ACTION_STATUS_FAILED, status)
250+
251+
// Unchecked type assertion: this panics instead of failing cleanly if the
// "error" field is missing or is not a string value.
errMsg := rv.Fields["error"].GetKind().(*structpb.Value_StringValue).StringValue
252+
require.Contains(t, errMsg, "context canceled")
253+
254+
// Give a small grace period for goroutines to clean up
255+
time.Sleep(100 * time.Millisecond)
256+
257+
// Verify no goroutine leaks
258+
finalCount := runtime.NumGoroutine()
259+
require.LessOrEqual(t, finalCount, initialCount+1, "goroutine leak detected after context cancellation")
260+
})
87261
}

pkg/connectorbuilder/connectorbuilder.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ type CustomActionManager interface {
8585
ListActionSchemas(ctx context.Context) ([]*v2.BatonActionSchema, annotations.Annotations, error)
8686
GetActionSchema(ctx context.Context, name string) (*v2.BatonActionSchema, annotations.Annotations, error)
8787
InvokeAction(ctx context.Context, name string, args *structpb.Struct) (string, v2.BatonActionStatus, *structpb.Struct, annotations.Annotations, error)
88-
GetActionStatus(ctx context.Context, id string) (v2.BatonActionStatus, *structpb.Struct, annotations.Annotations, error)
88+
GetActionStatus(ctx context.Context, id string) (v2.BatonActionStatus, string, *structpb.Struct, annotations.Annotations, error)
8989
}
9090

9191
type RegisterActionManager interface {
@@ -1099,6 +1099,7 @@ func (b *builderImpl) InvokeAction(ctx context.Context, request *v2.InvokeAction
10991099

11001100
rv := &v2.InvokeActionResponse{
11011101
Id: id,
1102+
Name: request.GetName(),
11021103
Status: status,
11031104
Annotations: annos,
11041105
Response: resp,
@@ -1119,15 +1120,15 @@ func (b *builderImpl) GetActionStatus(ctx context.Context, request *v2.GetAction
11191120
return nil, fmt.Errorf("error: action manager not implemented")
11201121
}
11211122

1122-
status, rv, annos, err := b.actionManager.GetActionStatus(ctx, request.GetId())
1123+
status, name, rv, annos, err := b.actionManager.GetActionStatus(ctx, request.GetId())
11231124
if err != nil {
11241125
b.m.RecordTaskFailure(ctx, tt, b.nowFunc().Sub(start))
11251126
return nil, fmt.Errorf("error: getting action status failed: %w", err)
11261127
}
11271128

11281129
resp := &v2.GetActionStatusResponse{
11291130
Id: request.GetId(),
1130-
Name: request.GetName(),
1131+
Name: name,
11311132
Status: status,
11321133
Annotations: annos,
11331134
Response: rv,

proto/c1/connector/v2/action.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ message InvokeActionResponse {
4444
BatonActionStatus status = 2;
4545
repeated google.protobuf.Any annotations = 3;
4646
google.protobuf.Struct response = 4;
47+
string name = 5;
4748
}
4849

4950

5051
message GetActionStatusRequest {
51-
string name = 1;
52+
string name = 1 [deprecated = true];
5253
string id = 2;
5354
repeated google.protobuf.Any annotations = 3;
5455
}

0 commit comments

Comments
 (0)