Skip to content

Commit 8400602

Browse files
authored
Simple sending of return value to XCom (apache#56481)
This right now is very simple and will need to evolve over time -- it doens't support XCom backends for instance. We also don't handle any of the "serialization" format for more advanced types, such as datetime/time.Time etc. That will come later
1 parent 463d51c commit 8400602

File tree

12 files changed

+989
-870
lines changed

12 files changed

+989
-870
lines changed

go-sdk/bundle/bundlev1/registry.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import (
2828
)
2929

3030
type (
31-
// task is an interface of an task implementation.
3231
Task = worker.Task
3332
Bundle = worker.Bundle
3433

go-sdk/bundle/bundlev1/task.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424
"reflect"
2525
"runtime"
2626

27+
"github.com/apache/airflow/go-sdk/pkg/api"
28+
"github.com/apache/airflow/go-sdk/pkg/sdkcontext"
2729
"github.com/apache/airflow/go-sdk/sdk"
2830
)
2931

@@ -43,6 +45,7 @@ func NewTaskFunction(fn any) (Task, error) {
4345

4446
func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error {
4547
fnType := f.fn.Type()
48+
sdkClient := sdk.NewClient()
4649

4750
reflectArgs := make([]reflect.Value, fnType.NumIn())
4851
for i := range reflectArgs {
@@ -54,7 +57,7 @@ func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error {
5457
case isLogger(in):
5558
reflectArgs[i] = reflect.ValueOf(logger)
5659
case isClient(in):
57-
reflectArgs[i] = reflect.ValueOf(sdk.NewClient())
60+
reflectArgs[i] = reflect.ValueOf(sdkClient)
5861
default:
5962
// TODO: deal with other value types. For now they will all be Zero values unless it's a context
6063
reflectArgs[i] = reflect.Zero(in)
@@ -74,15 +77,26 @@ func (f *taskFunction) Execute(ctx context.Context, logger *slog.Logger) error {
7477
}
7578
}
7679
// If there are two results, convert the first only if it's not a nil pointer
77-
var res any
7880
if len(retValues) > 1 && (retValues[0].Kind() != reflect.Ptr || !retValues[0].IsNil()) {
79-
res = retValues[0].Interface()
81+
res := retValues[0].Interface()
82+
f.sendXcom(ctx, res, sdkClient, logger)
8083
}
81-
// TODO: send the result to XCom
82-
_ = res
8384
return err
8485
}
8586

87+
func (f *taskFunction) sendXcom(
88+
ctx context.Context,
89+
value any,
90+
c sdk.XComClient,
91+
logger *slog.Logger,
92+
) {
93+
workload := ctx.Value(sdkcontext.WorkloadContextKey).(api.ExecuteTaskWorkload)
94+
err := c.PushXCom(ctx, workload.TI, api.XComReturnValueKey, value)
95+
if err != nil {
96+
logger.ErrorContext(ctx, "Unable to set XCom", "err", err)
97+
}
98+
}
99+
86100
func (f *taskFunction) validateFn(fnType reflect.Type) error {
87101
if fnType.Kind() != reflect.Func {
88102
return fmt.Errorf("expected a func as input but was %s", fnType.Kind())

go-sdk/example/bundle/main.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"context"
2222
"fmt"
2323
"log/slog"
24+
"runtime"
2425
"time"
2526

2627
v1 "github.com/apache/airflow/go-sdk/bundle/bundlev1"
@@ -56,7 +57,7 @@ func main() {
5657
bundlev1server.Serve(&myBundle{})
5758
}
5859

59-
func extract(ctx context.Context, client sdk.Client, log *slog.Logger) error {
60+
func extract(ctx context.Context, client sdk.Client, log *slog.Logger) (any, error) {
6061
log.Info("Hello from task")
6162
conn, err := client.GetConnection(ctx, "test_http")
6263
if err != nil {
@@ -69,15 +70,19 @@ func extract(ctx context.Context, client sdk.Client, log *slog.Logger) error {
6970
// Once per loop,.check if we've been asked to cancel!
7071
select {
7172
case <-ctx.Done():
72-
return ctx.Err()
73+
return nil, ctx.Err()
7374
default:
7475
}
7576
log.Info("After the beep the time will be", "time", time.Now())
7677
time.Sleep(2 * time.Second)
7778
}
7879
log.Info("Goodbye from task")
7980

80-
return nil
81+
ret := map[string]any{
82+
"go_version": runtime.Version(),
83+
}
84+
85+
return ret, nil
8186
}
8287

8388
func transform(ctx context.Context, client sdk.VariableClient, log *slog.Logger) error {

0 commit comments

Comments
 (0)