Skip to content

Commit ec949b3

Browse files
switching from mutex to waitgroups and adding a subscriber type with done channel
1 parent 4a374e3 commit ec949b3

File tree

2 files changed

+62
-29
lines changed

2 files changed

+62
-29
lines changed

subscription.go

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,30 @@ import (
99
"github.com/graphql-go/graphql/language/ast"
1010
)
1111

12+
// Subscriber subscriber
13+
type Subscriber struct {
14+
message chan interface{}
15+
done chan interface{}
16+
}
17+
18+
// Message returns the subscriber message channel
19+
func (c *Subscriber) Message() chan interface{} {
20+
return c.message
21+
}
22+
23+
// Done returns the subscriber done channel
24+
func (c *Subscriber) Done() chan interface{} {
25+
return c.done
26+
}
27+
28+
// NewSubscriber creates a new subscriber
29+
func NewSubscriber(message, done chan interface{}) *Subscriber {
30+
return &Subscriber{
31+
message: message,
32+
done: done,
33+
}
34+
}
35+
1236
// ResultIteratorParams parameters passed to the result iterator handler
1337
type ResultIteratorParams struct {
1438
ResultCount int64 // number of results this iterator has processed
@@ -30,21 +54,28 @@ type subscriptionHanlderConfig struct {
3054
type ResultIterator struct {
3155
currentHandlerID int64
3256
count int64
33-
wg sync.WaitGroup
34-
ctx context.Context
57+
mx sync.Mutex
3558
ch chan *Result
36-
cancelFunc context.CancelFunc
59+
iterDone chan interface{}
60+
subDone chan interface{}
3761
cancelled bool
3862
handlers map[int64]*subscriptionHanlderConfig
3963
}
4064

65+
func (c *ResultIterator) incrimentCount() int64 {
66+
c.mx.Lock()
67+
defer c.mx.Unlock()
68+
c.count++
69+
return c.count
70+
}
71+
4172
// NewResultIterator creates a new iterator and starts handling message on the result channel
42-
func NewResultIterator(ctx context.Context, cancelFunc context.CancelFunc, ch chan *Result) *ResultIterator {
73+
func NewResultIterator(subDone chan interface{}, ch chan *Result) *ResultIterator {
4374
iterator := &ResultIterator{
4475
currentHandlerID: 0,
4576
count: 0,
46-
ctx: ctx,
47-
cancelFunc: cancelFunc,
77+
iterDone: make(chan interface{}),
78+
subDone: subDone,
4879
ch: ch,
4980
cancelled: false,
5081
handlers: map[int64]*subscriptionHanlderConfig{},
@@ -53,20 +84,18 @@ func NewResultIterator(ctx context.Context, cancelFunc context.CancelFunc, ch ch
5384
go func() {
5485
for {
5586
select {
56-
case <-iterator.ctx.Done():
87+
case <-iterator.iterDone:
88+
subDone <- true
5789
return
5890
case res := <-iterator.ch:
5991
if iterator.cancelled {
6092
return
6193
}
62-
iterator.wg.Wait()
63-
iterator.wg.Add(1)
64-
iterator.count++
65-
iterator.wg.Done()
94+
95+
count := iterator.incrimentCount()
6696
for _, h := range iterator.handlers {
67-
iterator.wg.Wait()
6897
h.handler(ResultIteratorParams{
69-
ResultCount: iterator.count,
98+
ResultCount: int64(count),
7099
Result: res,
71100
Done: h.doneFunc,
72101
Cancel: iterator.Cancel,
@@ -81,7 +110,9 @@ func NewResultIterator(ctx context.Context, cancelFunc context.CancelFunc, ch ch
81110

82111
// adds a new handler
83112
func (c *ResultIterator) addHandler(handler ResultIteratorFn) {
84-
c.wg.Add(1)
113+
c.mx.Lock()
114+
defer c.mx.Unlock()
115+
85116
handlerID := c.currentHandlerID + 1
86117
c.currentHandlerID = handlerID
87118
c.handlers[handlerID] = &subscriptionHanlderConfig{
@@ -90,17 +121,17 @@ func (c *ResultIterator) addHandler(handler ResultIteratorFn) {
90121
c.removeHandler(handlerID)
91122
},
92123
}
93-
c.wg.Done()
94124
}
95125

96126
// removes a handler and cancels if no more handlers exist
97127
func (c *ResultIterator) removeHandler(handlerID int64) {
98-
c.wg.Add(1)
128+
c.mx.Lock()
129+
defer c.mx.Unlock()
130+
99131
delete(c.handlers, handlerID)
100132
if len(c.handlers) == 0 {
101133
c.Cancel()
102134
}
103-
c.wg.Done()
104135
}
105136

106137
// ForEach adds a handler and handles each message as they come
@@ -111,7 +142,7 @@ func (c *ResultIterator) ForEach(handler ResultIteratorFn) {
111142
// Cancel cancels the iterator
112143
func (c *ResultIterator) Cancel() {
113144
c.cancelled = true
114-
c.cancelFunc()
145+
c.iterDone <- true
115146
}
116147

117148
// SubscribeParams parameters for subscribing
@@ -129,30 +160,29 @@ type SubscribeParams struct {
129160
// Subscribe performs a subscribe operation
130161
func Subscribe(p SubscribeParams) *ResultIterator {
131162
resultChannel := make(chan *Result)
163+
doneChannel := make(chan interface{})
132164
// Use background context if no context was provided
133165
ctx := p.ContextValue
134166
if ctx == nil {
135167
ctx = context.Background()
136168
}
137169

138-
sctx, cancelFunc := context.WithCancel(ctx)
139-
140170
var mapSourceToResponse = func(payload interface{}) *Result {
141171
return Execute(ExecuteParams{
142172
Schema: p.Schema,
143173
Root: payload,
144174
AST: p.Document,
145175
OperationName: p.OperationName,
146176
Args: p.VariableValues,
147-
Context: sctx,
177+
Context: ctx,
148178
})
149179
}
150180

151181
go func() {
152-
153182
result := &Result{}
154183
defer func() {
155184
if err := recover(); err != nil {
185+
fmt.Println("SUBSCRIPTION RECOVERER", err)
156186
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
157187
}
158188
resultChannel <- result
@@ -165,7 +195,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
165195
OperationName: p.OperationName,
166196
Args: p.VariableValues,
167197
Result: result,
168-
Context: sctx,
198+
Context: ctx,
169199
})
170200

171201
if err != nil {
@@ -233,7 +263,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
233263
Source: p.RootValue,
234264
Args: args,
235265
Info: info,
236-
Context: sctx,
266+
Context: ctx,
237267
})
238268
if err != nil {
239269
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
@@ -249,12 +279,14 @@ func Subscribe(p SubscribeParams) *ResultIterator {
249279
}
250280

251281
switch fieldResult.(type) {
252-
case chan interface{}:
282+
case *Subscriber:
283+
sub := fieldResult.(*Subscriber)
253284
for {
254285
select {
255-
case <-sctx.Done():
286+
case <-doneChannel:
287+
sub.done <- true
256288
return
257-
case res := <-fieldResult.(chan interface{}):
289+
case res := <-sub.message:
258290
resultChannel <- mapSourceToResponse(res)
259291
}
260292
}
@@ -265,5 +297,5 @@ func Subscribe(p SubscribeParams) *ResultIterator {
265297
}()
266298

267299
// return a result iterator
268-
return NewResultIterator(sctx, cancelFunc, resultChannel)
300+
return NewResultIterator(doneChannel, resultChannel)
269301
}

subscription_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ func TestSubscription(t *testing.T) {
5252
return fmt.Sprintf("count=%v", p.Source), nil
5353
},
5454
Subscribe: func(p ResolveParams) (interface{}, error) {
55-
return m, nil
55+
sub := NewSubscriber(m, make(chan interface{}))
56+
return sub, nil
5657
},
5758
},
5859
"watch_should_fail": &Field{

0 commit comments

Comments
 (0)