@@ -9,6 +9,30 @@ import (
9
9
"github.com/graphql-go/graphql/language/ast"
10
10
)
11
11
12
+ // Subscriber subscriber
13
+ type Subscriber struct {
14
+ message chan interface {}
15
+ done chan interface {}
16
+ }
17
+
18
+ // Message returns the subscriber message channel
19
+ func (c * Subscriber ) Message () chan interface {} {
20
+ return c .message
21
+ }
22
+
23
+ // Done returns the subscriber done channel
24
+ func (c * Subscriber ) Done () chan interface {} {
25
+ return c .done
26
+ }
27
+
28
+ // NewSubscriber creates a new subscriber
29
+ func NewSubscriber (message , done chan interface {}) * Subscriber {
30
+ return & Subscriber {
31
+ message : message ,
32
+ done : done ,
33
+ }
34
+ }
35
+
12
36
// ResultIteratorParams parameters passed to the result iterator handler
13
37
type ResultIteratorParams struct {
14
38
ResultCount int64 // number of results this iterator has processed
@@ -30,21 +54,28 @@ type subscriptionHanlderConfig struct {
30
54
type ResultIterator struct {
31
55
currentHandlerID int64
32
56
count int64
33
- wg sync.WaitGroup
34
- ctx context.Context
57
+ mx sync.Mutex
35
58
ch chan * Result
36
- cancelFunc context.CancelFunc
59
+ iterDone chan interface {}
60
+ subDone chan interface {}
37
61
cancelled bool
38
62
handlers map [int64 ]* subscriptionHanlderConfig
39
63
}
40
64
65
+ func (c * ResultIterator ) incrimentCount () int64 {
66
+ c .mx .Lock ()
67
+ defer c .mx .Unlock ()
68
+ c .count ++
69
+ return c .count
70
+ }
71
+
41
72
// NewResultIterator creates a new iterator and starts handling message on the result channel
42
- func NewResultIterator (ctx context. Context , cancelFunc context. CancelFunc , ch chan * Result ) * ResultIterator {
73
+ func NewResultIterator (subDone chan interface {} , ch chan * Result ) * ResultIterator {
43
74
iterator := & ResultIterator {
44
75
currentHandlerID : 0 ,
45
76
count : 0 ,
46
- ctx : ctx ,
47
- cancelFunc : cancelFunc ,
77
+ iterDone : make ( chan interface {}) ,
78
+ subDone : subDone ,
48
79
ch : ch ,
49
80
cancelled : false ,
50
81
handlers : map [int64 ]* subscriptionHanlderConfig {},
@@ -53,20 +84,18 @@ func NewResultIterator(ctx context.Context, cancelFunc context.CancelFunc, ch ch
53
84
go func () {
54
85
for {
55
86
select {
56
- case <- iterator .ctx .Done ():
87
+ case <- iterator .iterDone :
88
+ subDone <- true
57
89
return
58
90
case res := <- iterator .ch :
59
91
if iterator .cancelled {
60
92
return
61
93
}
62
- iterator .wg .Wait ()
63
- iterator .wg .Add (1 )
64
- iterator .count ++
65
- iterator .wg .Done ()
94
+
95
+ count := iterator .incrimentCount ()
66
96
for _ , h := range iterator .handlers {
67
- iterator .wg .Wait ()
68
97
h .handler (ResultIteratorParams {
69
- ResultCount : iterator . count ,
98
+ ResultCount : int64 ( count ) ,
70
99
Result : res ,
71
100
Done : h .doneFunc ,
72
101
Cancel : iterator .Cancel ,
@@ -81,7 +110,9 @@ func NewResultIterator(ctx context.Context, cancelFunc context.CancelFunc, ch ch
81
110
82
111
// adds a new handler
83
112
func (c * ResultIterator ) addHandler (handler ResultIteratorFn ) {
84
- c .wg .Add (1 )
113
+ c .mx .Lock ()
114
+ defer c .mx .Unlock ()
115
+
85
116
handlerID := c .currentHandlerID + 1
86
117
c .currentHandlerID = handlerID
87
118
c .handlers [handlerID ] = & subscriptionHanlderConfig {
@@ -90,17 +121,17 @@ func (c *ResultIterator) addHandler(handler ResultIteratorFn) {
90
121
c .removeHandler (handlerID )
91
122
},
92
123
}
93
- c .wg .Done ()
94
124
}
95
125
96
126
// removes a handler and cancels if no more handlers exist
97
127
func (c * ResultIterator ) removeHandler (handlerID int64 ) {
98
- c .wg .Add (1 )
128
+ c .mx .Lock ()
129
+ defer c .mx .Unlock ()
130
+
99
131
delete (c .handlers , handlerID )
100
132
if len (c .handlers ) == 0 {
101
133
c .Cancel ()
102
134
}
103
- c .wg .Done ()
104
135
}
105
136
106
137
// ForEach adds a handler and handles each message as they come
@@ -111,7 +142,7 @@ func (c *ResultIterator) ForEach(handler ResultIteratorFn) {
111
142
// Cancel cancels the iterator
112
143
func (c * ResultIterator ) Cancel () {
113
144
c .cancelled = true
114
- c .cancelFunc ()
145
+ c .iterDone <- true
115
146
}
116
147
117
148
// SubscribeParams parameters for subscribing
@@ -129,30 +160,29 @@ type SubscribeParams struct {
129
160
// Subscribe performs a subscribe operation
130
161
func Subscribe (p SubscribeParams ) * ResultIterator {
131
162
resultChannel := make (chan * Result )
163
+ doneChannel := make (chan interface {})
132
164
// Use background context if no context was provided
133
165
ctx := p .ContextValue
134
166
if ctx == nil {
135
167
ctx = context .Background ()
136
168
}
137
169
138
- sctx , cancelFunc := context .WithCancel (ctx )
139
-
140
170
var mapSourceToResponse = func (payload interface {}) * Result {
141
171
return Execute (ExecuteParams {
142
172
Schema : p .Schema ,
143
173
Root : payload ,
144
174
AST : p .Document ,
145
175
OperationName : p .OperationName ,
146
176
Args : p .VariableValues ,
147
- Context : sctx ,
177
+ Context : ctx ,
148
178
})
149
179
}
150
180
151
181
go func () {
152
-
153
182
result := & Result {}
154
183
defer func () {
155
184
if err := recover (); err != nil {
185
+ fmt .Println ("SUBSCRIPTION RECOVERER" , err )
156
186
result .Errors = append (result .Errors , gqlerrors .FormatError (err .(error )))
157
187
}
158
188
resultChannel <- result
@@ -165,7 +195,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
165
195
OperationName : p .OperationName ,
166
196
Args : p .VariableValues ,
167
197
Result : result ,
168
- Context : sctx ,
198
+ Context : ctx ,
169
199
})
170
200
171
201
if err != nil {
@@ -233,7 +263,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
233
263
Source : p .RootValue ,
234
264
Args : args ,
235
265
Info : info ,
236
- Context : sctx ,
266
+ Context : ctx ,
237
267
})
238
268
if err != nil {
239
269
result .Errors = append (result .Errors , gqlerrors .FormatError (err .(error )))
@@ -249,12 +279,14 @@ func Subscribe(p SubscribeParams) *ResultIterator {
249
279
}
250
280
251
281
switch fieldResult .(type ) {
252
- case chan interface {}:
282
+ case * Subscriber :
283
+ sub := fieldResult .(* Subscriber )
253
284
for {
254
285
select {
255
- case <- sctx .Done ():
286
+ case <- doneChannel :
287
+ sub .done <- true
256
288
return
257
- case res := <- fieldResult .( chan interface {}) :
289
+ case res := <- sub . message :
258
290
resultChannel <- mapSourceToResponse (res )
259
291
}
260
292
}
@@ -265,5 +297,5 @@ func Subscribe(p SubscribeParams) *ResultIterator {
265
297
}()
266
298
267
299
// return a result iterator
268
- return NewResultIterator (sctx , cancelFunc , resultChannel )
300
+ return NewResultIterator (doneChannel , resultChannel )
269
301
}
0 commit comments