Skip to content

Commit 515159d

Browse files
committed
Added channel-based streaming flows.
1 parent 590fa04 commit 515159d

File tree

3 files changed

+367
-0
lines changed

3 files changed

+367
-0
lines changed

go/core/flow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(
154154
return nil
155155
}
156156
output, err := (*ActionDef[In, Out, Stream])(f).Run(ctx, input, cb)
157+
if errors.Is(err, errStop) {
158+
// Consumer broke out of the loop; don't yield again.
159+
return
160+
}
157161
if err != nil {
158162
yield(nil, err)
159163
} else {

go/genkit/x/genkit.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
// Package x provides experimental Genkit APIs.
18+
//
19+
// APIs in this package are under active development and may change in any
20+
// minor version release. Use with caution in production environments.
21+
//
22+
// When these APIs stabilize, they will be moved to the genkit package
23+
// and these exports will be deprecated.
24+
package x
25+
26+
import (
27+
"context"
28+
29+
"github.com/firebase/genkit/go/core"
30+
"github.com/firebase/genkit/go/genkit"
31+
)
32+
33+
// StreamingFunc is a streaming function that uses a channel instead of a callback.
34+
//
35+
// The function receives a send-only channel to which it should write stream chunks.
36+
// The channel is managed by the framework and will be closed automatically after
37+
// the function returns. The function should NOT close the channel itself.
38+
//
39+
// When writing to the channel, the function should respect context cancellation:
40+
//
41+
// select {
42+
// case streamCh <- chunk:
43+
// case <-ctx.Done():
44+
// return zero, ctx.Err()
45+
// }
46+
type StreamingFunc[In, Out, Stream any] = func(ctx context.Context, input In, streamCh chan<- Stream) (Out, error)
47+
48+
// DefineStreamingFlow defines a streaming flow that uses a channel for streaming,
49+
// registers it as a [core.Action] of type Flow, and returns a [core.Flow] runner.
50+
//
51+
// Unlike [genkit.DefineStreamingFlow] which uses a callback, this function accepts
52+
// a [StreamingFunc] that writes stream chunks to a channel. This can be
53+
// more ergonomic when integrating with other channel-based APIs or when the
54+
// streaming logic is more naturally expressed with channels.
55+
//
56+
// The channel passed to the function is unbuffered and managed by the framework.
57+
// The function should NOT close the channel - it will be closed automatically
58+
// after the function returns.
59+
//
60+
// Example:
61+
//
62+
// countdown := x.DefineStreamingFlow(g, "countdown",
63+
// func(ctx context.Context, start int, streamCh chan<- int) (string, error) {
64+
// for i := start; i > 0; i-- {
65+
// select {
66+
// case streamCh <- i:
67+
// case <-ctx.Done():
68+
// return "", ctx.Err()
69+
// }
70+
// }
71+
// return "liftoff!", nil
72+
// })
73+
//
74+
// // Run with streaming
75+
// for val, err := range countdown.Stream(ctx, 5) {
76+
// if err != nil {
77+
// log.Fatal(err)
78+
// }
79+
// if val.Done {
80+
// fmt.Println(val.Output) // "liftoff!"
81+
// } else {
82+
// fmt.Println(val.Stream) // 5, 4, 3, 2, 1
83+
// }
84+
// }
85+
func DefineStreamingFlow[In, Out, Stream any](g *genkit.Genkit, name string, fn StreamingFunc[In, Out, Stream]) *core.Flow[In, Out, Stream] {
86+
// Wrap the channel-based function to work with the callback-based API
87+
wrappedFn := func(ctx context.Context, input In, sendChunk core.StreamCallback[Stream]) (Out, error) {
88+
if sendChunk == nil {
89+
// Create a channel that discards all values
90+
discardCh := make(chan Stream)
91+
go func() {
92+
for range discardCh {
93+
}
94+
}()
95+
output, err := fn(ctx, input, discardCh)
96+
close(discardCh)
97+
return output, err
98+
}
99+
100+
// Create a cancellable context for the user function.
101+
// We cancel this if the callback returns an error, signaling
102+
// the user's function to stop producing chunks.
103+
fnCtx, cancel := context.WithCancel(ctx)
104+
defer cancel()
105+
106+
streamCh := make(chan Stream)
107+
108+
type result struct {
109+
output Out
110+
err error
111+
}
112+
resultCh := make(chan result, 1)
113+
114+
go func() {
115+
output, err := fn(fnCtx, input, streamCh)
116+
close(streamCh)
117+
resultCh <- result{output, err}
118+
}()
119+
120+
// Forward chunks from the channel to the callback.
121+
// If callback returns an error, cancel context and drain remaining
122+
// chunks to prevent the goroutine from blocking.
123+
var callbackErr error
124+
for chunk := range streamCh {
125+
if callbackErr != nil {
126+
continue
127+
}
128+
if err := sendChunk(ctx, chunk); err != nil {
129+
callbackErr = err
130+
cancel()
131+
}
132+
}
133+
134+
res := <-resultCh
135+
if callbackErr != nil {
136+
var zero Out
137+
return zero, callbackErr
138+
}
139+
return res.output, res.err
140+
}
141+
142+
return genkit.DefineStreamingFlow(g, name, wrappedFn)
143+
}

go/genkit/x/genkit_test.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package x
18+
19+
import (
20+
"context"
21+
"errors"
22+
"slices"
23+
"sync/atomic"
24+
"testing"
25+
"time"
26+
27+
"github.com/firebase/genkit/go/genkit"
28+
)
29+
30+
func TestDefineStreamingFlow(t *testing.T) {
31+
t.Run("streams values via channel", func(t *testing.T) {
32+
ctx := context.Background()
33+
g := genkit.Init(ctx)
34+
35+
flow := DefineStreamingFlow(g, "test/counter", func(ctx context.Context, n int, stream chan<- int) (string, error) {
36+
for i := 0; i < n; i++ {
37+
select {
38+
case stream <- i:
39+
case <-ctx.Done():
40+
return "", ctx.Err()
41+
}
42+
}
43+
return "done", nil
44+
})
45+
46+
var streamedValues []int
47+
var finalOutput string
48+
49+
for v, err := range flow.Stream(ctx, 3) {
50+
if err != nil {
51+
t.Fatalf("Stream error: %v", err)
52+
}
53+
if v.Done {
54+
finalOutput = v.Output
55+
} else {
56+
streamedValues = append(streamedValues, v.Stream)
57+
}
58+
}
59+
60+
wantStreamed := []int{0, 1, 2}
61+
if !slices.Equal(streamedValues, wantStreamed) {
62+
t.Errorf("streamed values = %v, want %v", streamedValues, wantStreamed)
63+
}
64+
if finalOutput != "done" {
65+
t.Errorf("final output = %q, want %q", finalOutput, "done")
66+
}
67+
})
68+
69+
t.Run("runs without streaming", func(t *testing.T) {
70+
ctx := context.Background()
71+
g := genkit.Init(ctx)
72+
73+
flow := DefineStreamingFlow(g, "test/nostream", func(ctx context.Context, n int, stream chan<- int) (string, error) {
74+
for i := 0; i < n; i++ {
75+
stream <- i
76+
}
77+
return "complete", nil
78+
})
79+
80+
output, err := flow.Run(ctx, 3)
81+
if err != nil {
82+
t.Fatalf("Run error: %v", err)
83+
}
84+
if output != "complete" {
85+
t.Errorf("output = %q, want %q", output, "complete")
86+
}
87+
})
88+
89+
t.Run("handles errors", func(t *testing.T) {
90+
ctx := context.Background()
91+
g := genkit.Init(ctx)
92+
93+
expectedErr := errors.New("flow failed")
94+
flow := DefineStreamingFlow(g, "test/failing", func(ctx context.Context, _ int, stream chan<- int) (string, error) {
95+
return "", expectedErr
96+
})
97+
98+
var gotErr error
99+
for _, err := range flow.Stream(ctx, 1) {
100+
if err != nil {
101+
gotErr = err
102+
}
103+
}
104+
105+
if gotErr == nil {
106+
t.Error("expected error, got nil")
107+
}
108+
})
109+
110+
t.Run("handles context cancellation", func(t *testing.T) {
111+
ctx := context.Background()
112+
g := genkit.Init(ctx)
113+
114+
flow := DefineStreamingFlow(g, "test/cancel", func(ctx context.Context, n int, stream chan<- int) (int, error) {
115+
for i := 0; i < n; i++ {
116+
select {
117+
case stream <- i:
118+
case <-ctx.Done():
119+
return 0, ctx.Err()
120+
}
121+
}
122+
return n, nil
123+
})
124+
125+
cancelCtx, cancel := context.WithCancel(ctx)
126+
cancel()
127+
128+
var gotErr error
129+
for _, err := range flow.Stream(cancelCtx, 100) {
130+
if err != nil {
131+
gotErr = err
132+
}
133+
}
134+
135+
if gotErr == nil {
136+
t.Error("expected context cancellation error, got nil")
137+
}
138+
})
139+
140+
t.Run("handles empty stream", func(t *testing.T) {
141+
ctx := context.Background()
142+
g := genkit.Init(ctx)
143+
144+
flow := DefineStreamingFlow(g, "test/empty", func(ctx context.Context, _ struct{}, stream chan<- int) (string, error) {
145+
return "empty", nil
146+
})
147+
148+
var streamedValues []int
149+
var finalOutput string
150+
151+
for v, err := range flow.Stream(ctx, struct{}{}) {
152+
if err != nil {
153+
t.Fatalf("Stream error: %v", err)
154+
}
155+
if v.Done {
156+
finalOutput = v.Output
157+
} else {
158+
streamedValues = append(streamedValues, v.Stream)
159+
}
160+
}
161+
162+
if len(streamedValues) != 0 {
163+
t.Errorf("streamed values = %v, want empty", streamedValues)
164+
}
165+
if finalOutput != "empty" {
166+
t.Errorf("final output = %q, want %q", finalOutput, "empty")
167+
}
168+
})
169+
170+
t.Run("handles consumer breaking early", func(t *testing.T) {
171+
ctx := context.Background()
172+
g := genkit.Init(ctx)
173+
174+
var produced atomic.Int32
175+
flow := DefineStreamingFlow(g, "test/earlybreak", func(ctx context.Context, n int, stream chan<- int) (string, error) {
176+
for i := 0; i < n; i++ {
177+
select {
178+
case stream <- i:
179+
produced.Add(1)
180+
case <-ctx.Done():
181+
return "cancelled", ctx.Err()
182+
}
183+
}
184+
return "done", nil
185+
})
186+
187+
var received []int
188+
done := make(chan struct{})
189+
go func() {
190+
defer close(done)
191+
for v, err := range flow.Stream(ctx, 1000) {
192+
if err != nil {
193+
return
194+
}
195+
if !v.Done {
196+
received = append(received, v.Stream)
197+
if len(received) >= 3 {
198+
break // Stop early
199+
}
200+
}
201+
}
202+
}()
203+
204+
// Should complete without deadlock
205+
select {
206+
case <-done:
207+
// Success - no deadlock
208+
case <-time.After(2 * time.Second):
209+
t.Fatal("timeout - likely deadlock when consumer breaks early")
210+
}
211+
212+
if len(received) != 3 {
213+
t.Errorf("received %d values, want 3", len(received))
214+
}
215+
216+
// Producer should have been signaled to stop (though may have produced a few more)
217+
// The important thing is no deadlock occurred
218+
t.Logf("producer created %d chunks before stopping", produced.Load())
219+
})
220+
}

0 commit comments

Comments
 (0)