@@ -52,6 +52,8 @@ type ScalarFuncConfig struct {
5252// bindData holds bind data accessible during execution.
5353type bindData struct {
5454 connId uint64
55+ // We ignore the linter because we need to pass the context through C memory.
56+ ctx context.Context //nolint:containedctx
5557}
5658
5759// ScalarUDFArg contains scalar UDF argument metadata and the optional argument.
@@ -69,9 +71,11 @@ type (
6971 // RowContextExecutorFn accepts a row-based execution function using a context.
7072 // It takes a context and the row values, and returns the row execution result, or error.
7173 RowContextExecutorFn func (ctx context.Context , values []driver.Value ) (any , error )
72- // ScalarBinderFn takes a context and the scalar function's arguments.
73- // It returns the updated context, which can now contain arbitrary data available during execution.
74- ScalarBinderFn func (ctx context.Context , args []ScalarUDFArg ) (context.Context , error )
74+ // ScalarBinderFn takes a (parent) context and the scalar function's arguments.
75+ // It returns the possibly updated child context (can be the same as the parent).
76+ // The child context can contain additional arbitrary data available during execution.
77+ // Please ensure correct context inheritance.
78+ ScalarBinderFn func (parentCtx context.Context , args []ScalarUDFArg ) (context.Context , error )
7579)
7680
7781// ScalarFuncExecutor contains the functions to execute a user-defined scalar function.
@@ -107,16 +111,27 @@ func (s *scalarFuncContext) Config() ScalarFuncConfig {
107111
108112// RowExecutor returns a RowExecutorFn executing the scalar function.
109113// It uses the bindInfo to get the execution context.
110- func (s * scalarFuncContext ) RowExecutor (info * bindData ) RowExecutorFn {
114+ func (s * scalarFuncContext ) RowExecutor (info * bindData ) ( RowExecutorFn , error ) {
111115 e := s .f .Executor ()
112116 if e .RowExecutor != nil {
113- return e .RowExecutor
117+ return e .RowExecutor , nil
114118 }
115- ctx := s .ctxStore .load (info .connId )
116119
117- return func (values []driver.Value ) (any , error ) {
118- return e .RowContextExecutor (ctx , values )
120+ // Parent context cancellation propagates to children,
121+ // therefore, it is enough to check the child context here.
122+ if info .ctx != nil {
123+ if err := info .ctx .Err (); err != nil {
124+ return nil , err
125+ }
126+ } else {
127+ // No child context means that there is no custom bind function.
128+ // Retrieve the parent context from the connection context store.
129+ info .ctx = s .ctxStore .load (info .connId )
119130 }
131+
132+ return func (values []driver.Value ) (any , error ) {
133+ return e .RowContextExecutor (info .ctx , values )
134+ }, nil
120135}
121136
122137// RegisterScalarUDF registers a user-defined scalar function.
@@ -212,10 +227,14 @@ func scalar_udf_callback(functionInfoPtr, inputPtr, outputPtr unsafe.Pointer) {
212227 values := make ([]driver.Value , length )
213228
214229 // Execute the user-defined scalar function for each row.
215- f := funcCtx .RowExecutor (pinnedBindData )
230+ f , err := funcCtx .RowExecutor (pinnedBindData )
231+ if err != nil {
232+ mapping .ScalarFunctionSetError (functionInfo , getError (errAPI , err ).Error ())
233+ return
234+ }
235+
216236 for rowIdx := range inputChunk .GetSize () {
217237 // Get each column value.
218- var err error
219238 nullRow := false
220239 for colIdx := range length {
221240 if values [colIdx ], err = inputChunk .GetValue (colIdx , rowIdx ); err != nil {
@@ -283,6 +302,8 @@ func scalar_udf_bind_callback(bindInfoPtr unsafe.Pointer) {
283302 mapping .ScalarFunctionGetClientContext (bindInfo , & clientCtx )
284303 defer mapping .DestroyClientContext (& clientCtx )
285304
305+ // We need the connId to retrieve the correct parent context.
306+ // Then, we store the child context in data.
286307 connId := mapping .ClientContextGetConnectionId (clientCtx )
287308 data := bindData {connId : uint64 (connId )}
288309
@@ -291,11 +312,12 @@ func scalar_udf_bind_callback(bindInfoPtr unsafe.Pointer) {
291312
292313 // Get any custom bind data by invoking the custom bind function.
293314 if funcCtx .f .Executor ().ScalarBinder != nil {
294- err := funcCtx .bind (clientCtx , bindInfo , uint64 (connId ))
315+ bindCtx , err := funcCtx .bind (clientCtx , bindInfo , uint64 (connId ))
295316 if err != nil {
296317 mapping .ScalarFunctionBindSetError (bindInfo , err .Error ())
297318 return
298319 }
320+ data .ctx = bindCtx
299321 }
300322
301323 // Set the copy callback of the bind info.
@@ -343,26 +365,30 @@ func getScalarUDFArg(clientCtx mapping.ClientContext, bindInfo mapping.BindInfo,
343365 return arg , nil
344366}
345367
346- func (s * scalarFuncContext ) bind (clientCtx mapping.ClientContext , bindInfo mapping.BindInfo , connId uint64 ) error {
368+ func (s * scalarFuncContext ) bind (clientCtx mapping.ClientContext , bindInfo mapping.BindInfo , connId uint64 ) (context. Context , error ) {
347369 ctx := s .ctxStore .load (connId )
348370 argCount := mapping .ScalarFunctionBindGetArgumentCount (bindInfo )
349371
350372 var args []ScalarUDFArg
351373 for i := range int (argCount ) {
352374 arg , err := getScalarUDFArg (clientCtx , bindInfo , i )
353375 if err != nil {
354- return err
376+ return nil , err
355377 }
356378 args = append (args , arg )
357379 }
358380
359381 bindCtx , err := s .f .Executor ().ScalarBinder (ctx , args )
360382 if err != nil {
361- return err
383+ return nil , err
362384 }
363385
364- s .ctxStore .store (connId , bindCtx , true )
365- return nil
386+ // Propagate the parent context, if the child context is nil.
387+ if ctx != nil && bindCtx == nil {
388+ bindCtx = ctx
389+ }
390+
391+ return bindCtx , nil
366392}
367393
368394func registerInputParams (config ScalarFuncConfig , f mapping.ScalarFunction ) error {
0 commit comments