Skip to content

Commit 9cf0da7

Browse files
Updating ResultIterator api and adding doneFunc per handler
1 parent a8d0d00 commit 9cf0da7

File tree

2 files changed

+82
-33
lines changed

2 files changed

+82
-33
lines changed

subscription.go

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,50 @@ import (
99
"github.com/graphql-go/graphql/language/ast"
1010
)
1111

12-
type ResultIteratorFn func(count int64, result *Result, doneFunc func())
12+
// ResultIteratorParams parameters passed to the result iterator handler
13+
type ResultIteratorParams struct {
14+
ResultCount int64 // number of results this iterator has processed
15+
Result *Result // the current result
16+
Done func() // Removes the current handler
17+
Cancel func() // Cancels the iterator, same as iterator.Cancel()
18+
}
19+
20+
// ResultIteratorFn a result iterator handler
21+
type ResultIteratorFn func(p ResultIteratorParams)
22+
23+
// holds subscription handler data
24+
type subscriptionHanlderConfig struct {
25+
handler ResultIteratorFn
26+
doneFunc func()
27+
}
1328

29+
// ResultIterator handles processing results from a chan *Result
1430
type ResultIterator struct {
15-
count int64
16-
wg sync.WaitGroup
17-
ctx context.Context
18-
ch chan *Result
19-
cancelFunc context.CancelFunc
20-
cancelled bool
21-
handlers []ResultIteratorFn
31+
currentHandlerID int64
32+
count int64
33+
wg sync.WaitGroup
34+
ctx context.Context
35+
ch chan *Result
36+
cancelFunc context.CancelFunc
37+
cancelled bool
38+
handlers map[int64]*subscriptionHanlderConfig
2239
}
2340

41+
// NewResultIterator creates a new iterator and starts handling message on the result channel
2442
func NewResultIterator(ctx context.Context, ch chan *Result) *ResultIterator {
2543
if ctx == nil {
2644
ctx = context.Background()
2745
}
2846

2947
cctx, cancelFunc := context.WithCancel(ctx)
3048
iterator := &ResultIterator{
31-
count: 0,
32-
ctx: cctx,
33-
ch: ch,
34-
cancelFunc: cancelFunc,
35-
cancelled: false,
36-
handlers: []ResultIteratorFn{},
49+
currentHandlerID: 0,
50+
count: 0,
51+
ctx: cctx,
52+
ch: ch,
53+
cancelFunc: cancelFunc,
54+
cancelled: false,
55+
handlers: map[int64]*subscriptionHanlderConfig{},
3756
}
3857

3958
go func() {
@@ -49,9 +68,14 @@ func NewResultIterator(ctx context.Context, ch chan *Result) *ResultIterator {
4968
iterator.wg.Add(1)
5069
iterator.count++
5170
iterator.wg.Done()
52-
for _, handler := range iterator.handlers {
71+
for _, h := range iterator.handlers {
5372
iterator.wg.Wait()
54-
handler(iterator.count, res, iterator.Done)
73+
h.handler(ResultIteratorParams{
74+
ResultCount: iterator.count,
75+
Result: res,
76+
Done: h.doneFunc,
77+
Cancel: iterator.Cancel,
78+
})
5579
}
5680
}
5781
}
@@ -60,17 +84,42 @@ func NewResultIterator(ctx context.Context, ch chan *Result) *ResultIterator {
6084
return iterator
6185
}
6286

63-
func (c *ResultIterator) ForEach(handler ResultIteratorFn) {
87+
// adds a new handler
88+
func (c *ResultIterator) addHandler(handler ResultIteratorFn) {
89+
c.wg.Add(1)
90+
handlerID := c.currentHandlerID + 1
91+
c.currentHandlerID = handlerID
92+
c.handlers[handlerID] = &subscriptionHanlderConfig{
93+
handler: handler,
94+
doneFunc: func() {
95+
c.removeHandler(handlerID)
96+
},
97+
}
98+
c.wg.Done()
99+
}
100+
101+
// removes a handler and cancels if no more handlers exist
102+
func (c *ResultIterator) removeHandler(handlerID int64) {
64103
c.wg.Add(1)
65-
c.handlers = append(c.handlers, handler)
104+
delete(c.handlers, handlerID)
105+
if len(c.handlers) == 0 {
106+
c.Cancel()
107+
}
66108
c.wg.Done()
67109
}
68110

69-
func (c *ResultIterator) Done() {
111+
// ForEach adds a handler and handles each message as they come
112+
func (c *ResultIterator) ForEach(handler ResultIteratorFn) {
113+
c.addHandler(handler)
114+
}
115+
116+
// Cancel cancels the iterator
117+
func (c *ResultIterator) Cancel() {
70118
c.cancelled = true
71119
c.cancelFunc()
72120
}
73121

122+
// SubscribeParams parameters for subscribing
74123
type SubscribeParams struct {
75124
Schema Schema
76125
Document *ast.Document

subscription_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,13 @@ func TestSubscription(t *testing.T) {
7979
})
8080

8181
// test a subscribe that should fail due to no return value
82-
failIterator.ForEach(func(count int64, res *Result, doneFunc func()) {
83-
if !res.HasErrors() {
82+
failIterator.ForEach(func(p ResultIteratorParams) {
83+
if !p.Result.HasErrors() {
8484
t.Errorf("subscribe failed to catch nil result from subscribe")
85-
doneFunc()
85+
p.Done()
8686
return
8787
}
88-
doneFunc()
88+
p.Done()
8989
return
9090
})
9191

@@ -95,27 +95,27 @@ func TestSubscription(t *testing.T) {
9595
ContextValue: context.Background(),
9696
})
9797

98-
resultIterator.ForEach(func(count int64, res *Result, doneFunc func()) {
99-
if res.HasErrors() {
100-
t.Errorf("subscribe error(s): %v", res.Errors)
101-
doneFunc()
98+
resultIterator.ForEach(func(p ResultIteratorParams) {
99+
if p.Result.HasErrors() {
100+
t.Errorf("subscribe error(s): %v", p.Result.Errors)
101+
p.Done()
102102
return
103103
}
104104

105-
if res.Data != nil {
106-
data := res.Data.(map[string]interface{})["watch_count"]
107-
expected := fmt.Sprintf("count=%d", count)
105+
if p.Result.Data != nil {
106+
data := p.Result.Data.(map[string]interface{})["watch_count"]
107+
expected := fmt.Sprintf("count=%d", p.ResultCount)
108108
actual := fmt.Sprintf("%v", data)
109109
if actual != expected {
110110
t.Errorf("subscription result error: expected %q, actual %q", expected, actual)
111-
doneFunc()
111+
p.Done()
112112
return
113113
}
114114

115115
// test the done func by quitting after 3 iterations
116116
// the publisher will publish up to 5
117-
if count >= int64(maxPublish-2) {
118-
doneFunc()
117+
if p.ResultCount >= int64(maxPublish-2) {
118+
p.Done()
119119
return
120120
}
121121
}

0 commit comments

Comments
 (0)