Skip to content

Commit 1a6cd82

Browse files
Merge pull request #114 from taniabogatsch/context-propagation
[Fix] Context propagation in custom scalar UDF bind functions
2 parents e74fa4f + 451baf0 commit 1a6cd82

File tree

2 files changed

+101
-20
lines changed

2 files changed

+101
-20
lines changed

scalar_udf.go

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ type ScalarFuncConfig struct {
5252
// bindData holds bind data accessible during execution.
5353
type bindData struct {
5454
connId uint64
55+
// We ignore the linter because we need to pass the context through C memory.
56+
ctx context.Context //nolint:containedctx
5557
}
5658

5759
// ScalarUDFArg contains scalar UDF argument metadata and the optional argument.
@@ -69,9 +71,11 @@ type (
6971
// RowContextExecutorFn accepts a row-based execution function using a context.
7072
// It takes a context and the row values, and returns the row execution result, or error.
7173
RowContextExecutorFn func(ctx context.Context, values []driver.Value) (any, error)
72-
// ScalarBinderFn takes a context and the scalar function's arguments.
73-
// It returns the updated context, which can now contain arbitrary data available during execution.
74-
ScalarBinderFn func(ctx context.Context, args []ScalarUDFArg) (context.Context, error)
74+
// ScalarBinderFn takes a (parent) context and the scalar function's arguments.
75+
// It returns the possibly updated child context (can be the same as the parent).
76+
// The child context can contain additional arbitrary data available during execution.
77+
// Please ensure correct context inheritance.
78+
ScalarBinderFn func(parentCtx context.Context, args []ScalarUDFArg) (context.Context, error)
7579
)
7680

7781
// ScalarFuncExecutor contains the functions to execute a user-defined scalar function.
@@ -107,16 +111,27 @@ func (s *scalarFuncContext) Config() ScalarFuncConfig {
107111

108112
// RowExecutor returns a RowExecutorFn executing the scalar function.
109113
// It uses the bindInfo to get the execution context.
110-
func (s *scalarFuncContext) RowExecutor(info *bindData) RowExecutorFn {
114+
func (s *scalarFuncContext) RowExecutor(info *bindData) (RowExecutorFn, error) {
111115
e := s.f.Executor()
112116
if e.RowExecutor != nil {
113-
return e.RowExecutor
117+
return e.RowExecutor, nil
114118
}
115-
ctx := s.ctxStore.load(info.connId)
116119

117-
return func(values []driver.Value) (any, error) {
118-
return e.RowContextExecutor(ctx, values)
120+
// Parent context cancellation propagates to children,
121+
// therefore, it is enough to check the child context here.
122+
if info.ctx != nil {
123+
if err := info.ctx.Err(); err != nil {
124+
return nil, err
125+
}
126+
} else {
127+
// No child context means that there is no custom bind function.
128+
// Retrieve the parent context from the connection context store.
129+
info.ctx = s.ctxStore.load(info.connId)
119130
}
131+
132+
return func(values []driver.Value) (any, error) {
133+
return e.RowContextExecutor(info.ctx, values)
134+
}, nil
120135
}
121136

122137
// RegisterScalarUDF registers a user-defined scalar function.
@@ -212,10 +227,14 @@ func scalar_udf_callback(functionInfoPtr, inputPtr, outputPtr unsafe.Pointer) {
212227
values := make([]driver.Value, length)
213228

214229
// Execute the user-defined scalar function for each row.
215-
f := funcCtx.RowExecutor(pinnedBindData)
230+
f, err := funcCtx.RowExecutor(pinnedBindData)
231+
if err != nil {
232+
mapping.ScalarFunctionSetError(functionInfo, getError(errAPI, err).Error())
233+
return
234+
}
235+
216236
for rowIdx := range inputChunk.GetSize() {
217237
// Get each column value.
218-
var err error
219238
nullRow := false
220239
for colIdx := range length {
221240
if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
@@ -283,6 +302,8 @@ func scalar_udf_bind_callback(bindInfoPtr unsafe.Pointer) {
283302
mapping.ScalarFunctionGetClientContext(bindInfo, &clientCtx)
284303
defer mapping.DestroyClientContext(&clientCtx)
285304

305+
// We need the connId to retrieve the correct parent context.
306+
// Then, we store the child context in data.
286307
connId := mapping.ClientContextGetConnectionId(clientCtx)
287308
data := bindData{connId: uint64(connId)}
288309

@@ -291,11 +312,12 @@ func scalar_udf_bind_callback(bindInfoPtr unsafe.Pointer) {
291312

292313
// Get any custom bind data by invoking the custom bind function.
293314
if funcCtx.f.Executor().ScalarBinder != nil {
294-
err := funcCtx.bind(clientCtx, bindInfo, uint64(connId))
315+
bindCtx, err := funcCtx.bind(clientCtx, bindInfo, uint64(connId))
295316
if err != nil {
296317
mapping.ScalarFunctionBindSetError(bindInfo, err.Error())
297318
return
298319
}
320+
data.ctx = bindCtx
299321
}
300322

301323
// Set the copy callback of the bind info.
@@ -343,26 +365,30 @@ func getScalarUDFArg(clientCtx mapping.ClientContext, bindInfo mapping.BindInfo,
343365
return arg, nil
344366
}
345367

346-
func (s *scalarFuncContext) bind(clientCtx mapping.ClientContext, bindInfo mapping.BindInfo, connId uint64) error {
368+
func (s *scalarFuncContext) bind(clientCtx mapping.ClientContext, bindInfo mapping.BindInfo, connId uint64) (context.Context, error) {
347369
ctx := s.ctxStore.load(connId)
348370
argCount := mapping.ScalarFunctionBindGetArgumentCount(bindInfo)
349371

350372
var args []ScalarUDFArg
351373
for i := range int(argCount) {
352374
arg, err := getScalarUDFArg(clientCtx, bindInfo, i)
353375
if err != nil {
354-
return err
376+
return nil, err
355377
}
356378
args = append(args, arg)
357379
}
358380

359381
bindCtx, err := s.f.Executor().ScalarBinder(ctx, args)
360382
if err != nil {
361-
return err
383+
return nil, err
362384
}
363385

364-
s.ctxStore.store(connId, bindCtx, true)
365-
return nil
386+
// Propagate the parent context, if the child context is nil.
387+
if ctx != nil && bindCtx == nil {
388+
bindCtx = ctx
389+
}
390+
391+
return bindCtx, nil
366392
}
367393

368394
func registerInputParams(config ScalarFuncConfig, f mapping.ScalarFunction) error {

scalar_udf_test.go

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ var currentInfo TypeInfo
1616

1717
type testCtxKeyType string
1818

19+
type counterKeyType struct{}
20+
1921
const (
2022
testCtxKey testCtxKeyType = "test_ctx_key"
2123
testBindCtxKey testCtxKeyType = "test_bind_ctx_key"
@@ -31,6 +33,7 @@ type (
3133
unionTestSUDF struct{}
3234
getConnIdUDF struct{}
3335
easterEggUDF struct{}
36+
siblingUDF struct{}
3437
errExecutorSUDF struct{}
3538
errInputNilSUDF struct{}
3639
errResultNilSUDF struct{}
@@ -122,21 +125,21 @@ func getEasterEgg(ctx context.Context, values []driver.Value) (any, error) {
122125
return strconv.Itoa(int(customBindData)), nil
123126
}
124127

125-
func bindEasterEgg(ctx context.Context, args []ScalarUDFArg) (context.Context, error) {
128+
func bindEasterEgg(parentCtx context.Context, args []ScalarUDFArg) (context.Context, error) {
126129
if !args[1].Foldable {
127130
return nil, errors.New("second argument must be foldable for bindEasterEgg")
128131
}
129132
if args[1].Value == nil {
130-
bindCtx := context.WithValue(ctx, testBindCtxKey, uint64(0))
133+
bindCtx := context.WithValue(parentCtx, testBindCtxKey, uint64(0))
131134
return bindCtx, nil
132135
}
133136

134137
switch v := args[1].Value.(type) {
135138
case int32:
136-
bindCtx := context.WithValue(ctx, testBindCtxKey, uint64(v))
139+
bindCtx := context.WithValue(parentCtx, testBindCtxKey, uint64(v))
137140
return bindCtx, nil
138141
case uint64:
139-
bindCtx := context.WithValue(ctx, testBindCtxKey, v)
142+
bindCtx := context.WithValue(parentCtx, testBindCtxKey, v)
140143
return bindCtx, nil
141144
default:
142145
return nil, errors.New("cannot cast second argument to uint64")
@@ -279,6 +282,38 @@ func (*unionTestSUDF) Executor() ScalarFuncExecutor {
279282
}
280283
}
281284

285+
func (u *siblingUDF) Config() ScalarFuncConfig {
286+
inputType, _ := NewTypeInfo(TYPE_BIGINT)
287+
resultType, _ := NewTypeInfo(TYPE_BIGINT)
288+
return ScalarFuncConfig{
289+
InputTypeInfos: []TypeInfo{inputType},
290+
ResultTypeInfo: resultType,
291+
}
292+
}
293+
294+
func (u *siblingUDF) Executor() ScalarFuncExecutor {
295+
return ScalarFuncExecutor{
296+
// ScalarBinder is called once per expression during query planning.
297+
// The context is stored per-connection, but we need to ensure that the second call does not see the first call's state.
298+
ScalarBinder: func(parentCtx context.Context, args []ScalarUDFArg) (context.Context, error) {
299+
// Get the current counter from the context (default is 0).
300+
counter, _ := parentCtx.Value(counterKeyType{}).(int)
301+
counter++
302+
// Return an updated context with the new counter.
303+
return context.WithValue(parentCtx, counterKeyType{}, counter), nil
304+
},
305+
306+
// RowContextExecutor is called for each row during query execution.
307+
// It receives the context stored through its specific binder.
308+
RowContextExecutor: func(ctx context.Context, values []driver.Value) (any, error) {
309+
counter := ctx.Value(counterKeyType{}).(int)
310+
arg := values[0].(int64)
311+
res := arg * int64(counter)
312+
return res, nil
313+
},
314+
}
315+
}
316+
282317
func TestSimpleScalarUDF(t *testing.T) {
283318
db := openDbWrapper(t, ``)
284319
defer closeDbWrapper(t, db)
@@ -618,6 +653,26 @@ func TestBindScalarUDF(t *testing.T) {
618653
require.ErrorContains(t, err, "second argument must be foldable for bindEasterEgg")
619654
}
620655

656+
func TestSiblingUDFs(t *testing.T) {
657+
// SiblingUDF is used to test non-shared state behavior across multiple UDF calls in one query.
658+
db := openDbWrapper(t, ``)
659+
defer closeDbWrapper(t, db)
660+
661+
conn := openConnWrapper(t, db, context.Background())
662+
defer closeConnWrapper(t, conn)
663+
664+
var udf *siblingUDF
665+
err := RegisterScalarUDF(conn, "sibling_udf", udf)
666+
require.NoError(t, err)
667+
668+
query := `SELECT sibling_udf(1) + sibling_udf(2) AS res`
669+
670+
var res int64
671+
err = conn.QueryRowContext(context.Background(), query).Scan(&res)
672+
require.NoError(t, err)
673+
require.Equal(t, int64(3), res)
674+
}
675+
621676
func TestErrScalarUDF(t *testing.T) {
622677
db := openDbWrapper(t, ``)
623678
defer closeDbWrapper(t, db)

0 commit comments

Comments
 (0)