diff --git a/go/core/flow.go b/go/core/flow.go
index ea514365c2..1a3526b445 100644
--- a/go/core/flow.go
+++ b/go/core/flow.go
@@ -154,6 +154,10 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(
 			return nil
 		}
 		output, err := (*ActionDef[In, Out, Stream])(f).Run(ctx, input, cb)
+		if errors.Is(err, errStop) {
+			// Consumer broke out of the loop; don't yield again.
+			return
+		}
 		if err != nil {
 			yield(nil, err)
 		} else {
diff --git a/go/genkit/x/genkit.go b/go/genkit/x/genkit.go
new file mode 100644
index 0000000000..3e962ad683
--- /dev/null
+++ b/go/genkit/x/genkit.go
@@ -0,0 +1,143 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+// Package x provides experimental Genkit APIs.
+//
+// APIs in this package are under active development and may change in any
+// minor version release. Use with caution in production environments.
+//
+// When these APIs stabilize, they will be moved to the genkit package
+// and these exports will be deprecated.
+package x
+
+import (
+	"context"
+
+	"github.com/firebase/genkit/go/core"
+	"github.com/firebase/genkit/go/genkit"
+)
+
+// StreamingFunc is a flow function that streams chunks over a channel instead of a callback.
+//
+// The function receives a send-only channel to which it should write stream chunks.
+// The channel is managed by the framework and will be closed automatically after
+// the function returns. The function should NOT close the channel itself.
+//
+// Every send should also watch ctx so the function stops once ctx is cancelled:
+//
+//	select {
+//	case streamCh <- chunk:
+//	case <-ctx.Done():
+//		return zero, ctx.Err() // zero stands for the zero value of Out
+//	}
+type StreamingFunc[In, Out, Stream any] = func(ctx context.Context, input In, streamCh chan<- Stream) (Out, error)
+
+// DefineStreamingFlow defines a streaming flow that uses a channel for streaming,
+// registers it as a [core.Action] of type Flow, and returns a [core.Flow] runner.
+//
+// Unlike [genkit.DefineStreamingFlow] which uses a callback, this function accepts
+// a [StreamingFunc] that writes stream chunks to a channel. This can be
+// more ergonomic when integrating with other channel-based APIs or when the
+// streaming logic is more naturally expressed with channels.
+//
+// The channel passed to the function is unbuffered and managed by the framework:
+// the function must NOT close it (it is closed after the function returns), and
+// the function's context is cancelled if the stream callback returns an error.
+//
+// Example:
+//
+//	countdown := x.DefineStreamingFlow(g, "countdown",
+//		func(ctx context.Context, start int, streamCh chan<- int) (string, error) {
+//			for i := start; i > 0; i-- {
+//				select {
+//				case streamCh <- i:
+//				case <-ctx.Done():
+//					return "", ctx.Err()
+//				}
+//			}
+//			return "liftoff!", nil
+//		})
+//
+//	// Run with streaming
+//	for val, err := range countdown.Stream(ctx, 5) {
+//		if err != nil {
+//			log.Fatal(err)
+//		}
+//		if val.Done {
+//			fmt.Println(val.Output) // "liftoff!"
+//		} else {
+//			fmt.Println(val.Stream) // 5, 4, 3, 2, 1
+//		}
+//	}
+func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] {
+	// Wrap the channel-based function to work with the callback-based API
+	wrappedFn := func(ctx context.Context, input In, sendChunk core.StreamCallback[Stream]) (Out, error) {
+		if sendChunk == nil {
+			// No stream consumer: hand fn a channel whose values are discarded.
+			discardCh := make(chan Stream)
+			go func() {
+				for range discardCh {
+				}
+			}()
+			// Deferred so the drain goroutine also exits if fn panics.
+			defer close(discardCh)
+			return fn(ctx, input, discardCh)
+		}
+
+		// Create a cancellable context for the user function.
+		// We cancel this if the callback returns an error, signaling
+		// the user's function to stop producing chunks.
+		fnCtx, cancel := context.WithCancel(ctx)
+		defer cancel()
+
+		streamCh := make(chan Stream)
+
+		type result struct {
+			output Out
+			err    error
+		}
+		resultCh := make(chan result, 1) // buffered: the producer never blocks on its final send
+
+		go func() {
+			output, err := fn(fnCtx, input, streamCh)
+			close(streamCh) // ends the forwarding loop below
+			resultCh <- result{output, err}
+		}()
+
+		// Forward chunks from the channel to the callback.
+		// If callback returns an error, cancel context and drain remaining
+		// chunks to prevent the goroutine from blocking.
+		var callbackErr error
+		for chunk := range streamCh {
+			if callbackErr != nil {
+				continue
+			}
+			if err := sendChunk(ctx, chunk); err != nil {
+				callbackErr = err
+				cancel()
+			}
+		}
+
+		res := <-resultCh
+		if callbackErr != nil {
+			var zero Out
+			return zero, callbackErr // the callback's error takes precedence over fn's result
+		}
+		return res.output, res.err
+	}
+
+	return genkit.DefineStreamingFlow(g, name, wrappedFn)
+}
diff --git a/go/genkit/x/genkit_test.go b/go/genkit/x/genkit_test.go
new file mode 100644
index 0000000000..81b4d15105
--- /dev/null
+++ b/go/genkit/x/genkit_test.go
@@ -0,0 +1,220 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package x
+
+import (
+	"context"
+	"errors"
+	"slices"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/firebase/genkit/go/genkit"
+)
+
+// TestDefineStreamingFlow exercises the channel-based DefineStreamingFlow through flow.Stream and flow.Run.
+func TestDefineStreamingFlow(t *testing.T) {
+	t.Run("streams values via channel", func(t *testing.T) {
+		ctx := context.Background()
+		g := genkit.Init(ctx)
+
+		flow := DefineStreamingFlow(g, "test/counter", func(ctx context.Context, n int, stream chan<- int) (string, error) {
+			for i := 0; i < n; i++ {
+				select {
+				case stream <- i:
+				case <-ctx.Done():
+					return "", ctx.Err()
+				}
+			}
+			return "done", nil
+		})
+
+		var streamedValues []int
+		var finalOutput string
+
+		for v, err := range flow.Stream(ctx, 3) {
+			if err != nil {
+				t.Fatalf("Stream error: %v", err)
+			}
+			if v.Done {
+				finalOutput = v.Output
+			} else {
+				streamedValues = append(streamedValues, v.Stream)
+			}
+		}
+
+		wantStreamed := []int{0, 1, 2}
+		if !slices.Equal(streamedValues, wantStreamed) {
+			t.Errorf("streamed values = %v, want %v", streamedValues, wantStreamed)
+		}
+		if finalOutput != "done" {
+			t.Errorf("final output = %q, want %q", finalOutput, "done")
+		}
+	})
+
+	t.Run("runs without streaming", func(t *testing.T) {
+		ctx := context.Background()
+		g := genkit.Init(ctx)
+
+		flow := DefineStreamingFlow(g, "test/nostream", func(ctx context.Context, n int, stream chan<- int) (string, error) {
+			for i := 0; i < n; i++ {
+				stream <- i // plain send with no ctx select; presumably drained by the wrapper when Run supplies no callback (verify)
+			}
+			return "complete", nil
+		})
+
+		output, err := flow.Run(ctx, 3)
+		if err != nil {
+			t.Fatalf("Run error: %v", err)
+		}
+		if output != "complete" {
+			t.Errorf("output = %q, want %q", output, "complete")
+		}
+	})
+
+	t.Run("handles errors", func(t *testing.T) {
+		ctx := context.Background()
+		g := genkit.Init(ctx)
+
+		expectedErr := errors.New("flow failed")
+		flow := DefineStreamingFlow(g, "test/failing", func(ctx context.Context, _ int, stream chan<- int) (string, error) {
+			return "", expectedErr
+		})
+
+		var gotErr error
+		for _, err := range flow.Stream(ctx, 1) {
+			if err != nil {
+				gotErr = err
+			}
+		}
+
+		if gotErr == nil { // only checks that some error surfaced, not that it is expectedErr
+			t.Error("expected error, got nil")
+		}
+	})
+
+	t.Run("handles context cancellation", func(t *testing.T) {
+		ctx := context.Background()
+		g := genkit.Init(ctx)
+
+		flow := DefineStreamingFlow(g, "test/cancel", func(ctx context.Context, n int, stream chan<- int) (int, error) {
+			for i := 0; i < n; i++ {
+				select {
+				case stream <- i:
+				case <-ctx.Done():
+					return 0, ctx.Err()
+				}
+			}
+			return n, nil
+		})
+
+		cancelCtx, cancel := context.WithCancel(ctx)
+		cancel() // cancel before streaming so the flow starts with an already-cancelled context
+
+		var gotErr error
+		for _, err := range flow.Stream(cancelCtx, 100) {
+			if err != nil {
+				gotErr = err
+			}
+		}
+
+		if gotErr == nil {
+			t.Error("expected context cancellation error, got nil")
+		}
+	})
+
+	t.Run("handles empty stream", func(t *testing.T) {
+		ctx := context.Background()
+		g := genkit.Init(ctx)
+
+		flow := DefineStreamingFlow(g, "test/empty", func(ctx context.Context, _ struct{}, stream chan<- int) (string, error) {
+			return "empty", nil
+		})
+
+		var streamedValues []int
+		var finalOutput string
+
+		for v, err := range flow.Stream(ctx, struct{}{}) {
+			if err != nil {
+				t.Fatalf("Stream error: %v", err)
+			}
+			if v.Done {
+				finalOutput = v.Output
+			} else {
+				streamedValues = append(streamedValues, v.Stream)
+			}
+		}
+
+		if len(streamedValues) != 0 {
+			t.Errorf("streamed values = %v, want empty", streamedValues)
+		}
+		if finalOutput != "empty" {
+			t.Errorf("final output = %q, want %q", finalOutput, "empty")
+		}
+	})
+
+	t.Run("handles consumer breaking early", func(t *testing.T) {
+		ctx := context.Background()
+		g := genkit.Init(ctx)
+
+		var produced atomic.Int32
+		flow := DefineStreamingFlow(g, "test/earlybreak", func(ctx context.Context, n int, stream chan<- int) (string, error) {
+			for i := 0; i < n; i++ {
+				select {
+				case stream <- i:
+					produced.Add(1)
+				case <-ctx.Done():
+					return "cancelled", ctx.Err()
+				}
+			}
+			return "done", nil
+		})
+
+		var received []int // appended to only by the goroutine below; read after <-done
+		done := make(chan struct{})
+		go func() {
+			defer close(done)
+			for v, err := range flow.Stream(ctx, 1000) {
+				if err != nil {
+					return // stop consuming; the len(received) check below reports the shortfall
+				}
+				if !v.Done {
+					received = append(received, v.Stream)
+					if len(received) >= 3 {
+						break // Stop early
+					}
+				}
+			}
+		}()
+
+		// Should complete without deadlock
+		select {
+		case <-done:
+			// Success - no deadlock
+		case <-time.After(2 * time.Second):
+			t.Fatal("timeout - likely deadlock when consumer breaks early")
+		}
+
+		if len(received) != 3 {
+			t.Errorf("received %d values, want 3", len(received))
+		}
+
+		// Producer should have been signaled to stop (it may have produced a few more chunks); the key point is that no deadlock occurred.
+		t.Logf("producer created %d chunks before stopping", produced.Load())
+	})
+}