Skip to content

Commit bd370d7

Browse files
added subscribe support
1 parent 24963e0 commit bd370d7

File tree

3 files changed

+354
-0
lines changed

3 files changed

+354
-0
lines changed

definition.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ func defineFieldMap(ttype Named, fieldMap Fields) (FieldDefinitionMap, error) {
534534
Description: field.Description,
535535
Type: field.Type,
536536
Resolve: field.Resolve,
537+
Subscribe: field.Subscribe,
537538
DeprecationReason: field.DeprecationReason,
538539
}
539540

@@ -606,6 +607,7 @@ type Field struct {
606607
Type Output `json:"type"`
607608
Args FieldConfigArgument `json:"args"`
608609
Resolve FieldResolveFn `json:"-"`
610+
Subscribe FieldResolveFn `json:"-"`
609611
DeprecationReason string `json:"deprecationReason"`
610612
Description string `json:"description"`
611613
}
@@ -625,6 +627,7 @@ type FieldDefinition struct {
625627
Type Output `json:"type"`
626628
Args []*Argument `json:"args"`
627629
Resolve FieldResolveFn `json:"-"`
630+
Subscribe FieldResolveFn `json:"-"`
628631
DeprecationReason string `json:"deprecationReason"`
629632
}
630633

subscription.go

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package graphql
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/graphql-go/graphql/gqlerrors"
8+
"github.com/graphql-go/graphql/language/ast"
9+
)
10+
11+
type ResultIteratorFn func(count int64, result *Result, doneFunc func())
12+
13+
type ResultIterator struct {
14+
count int64
15+
ctx context.Context
16+
ch chan *Result
17+
cancelFunc context.CancelFunc
18+
cancelled bool
19+
handlers []ResultIteratorFn
20+
}
21+
22+
func NewResultIterator(ctx context.Context, ch chan *Result) *ResultIterator {
23+
if ctx == nil {
24+
ctx = context.Background()
25+
}
26+
27+
cctx, cancelFunc := context.WithCancel(ctx)
28+
iterator := &ResultIterator{
29+
count: 0,
30+
ctx: cctx,
31+
ch: ch,
32+
cancelFunc: cancelFunc,
33+
cancelled: false,
34+
handlers: []ResultIteratorFn{},
35+
}
36+
37+
go func() {
38+
for {
39+
select {
40+
case <-iterator.ctx.Done():
41+
return
42+
case res := <-iterator.ch:
43+
if iterator.cancelled {
44+
return
45+
}
46+
iterator.count += 1
47+
for _, handler := range iterator.handlers {
48+
handler(iterator.count, res, iterator.Done)
49+
}
50+
}
51+
}
52+
}()
53+
54+
return iterator
55+
}
56+
57+
func (c *ResultIterator) ForEach(handler ResultIteratorFn) {
58+
c.handlers = append(c.handlers, handler)
59+
}
60+
61+
func (c *ResultIterator) Done() {
62+
c.cancelled = true
63+
c.cancelFunc()
64+
}
65+
66+
type SubscribeParams struct {
67+
Schema Schema
68+
Document *ast.Document
69+
RootValue interface{}
70+
ContextValue context.Context
71+
VariableValues map[string]interface{}
72+
OperationName string
73+
FieldResolver FieldResolveFn
74+
FieldSubscriber FieldResolveFn
75+
}
76+
77+
// Subscribe performs a subscribe operation
78+
func Subscribe(p SubscribeParams) *ResultIterator {
79+
resultChannel := make(chan *Result)
80+
// Use background context if no context was provided
81+
ctx := p.ContextValue
82+
if ctx == nil {
83+
ctx = context.Background()
84+
}
85+
86+
var mapSourceToResponse = func(payload interface{}) *Result {
87+
return Execute(ExecuteParams{
88+
Schema: p.Schema,
89+
Root: payload,
90+
AST: p.Document,
91+
OperationName: p.OperationName,
92+
Args: p.VariableValues,
93+
Context: p.ContextValue,
94+
})
95+
}
96+
97+
go func() {
98+
99+
result := &Result{}
100+
defer func() {
101+
if err := recover(); err != nil {
102+
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
103+
}
104+
resultChannel <- result
105+
}()
106+
107+
exeContext, err := buildExecutionContext(buildExecutionCtxParams{
108+
Schema: p.Schema,
109+
Root: p.RootValue,
110+
AST: p.Document,
111+
OperationName: p.OperationName,
112+
Args: p.VariableValues,
113+
Result: result,
114+
Context: p.ContextValue,
115+
})
116+
117+
if err != nil {
118+
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
119+
resultChannel <- result
120+
return
121+
}
122+
123+
operationType, err := getOperationRootType(p.Schema, exeContext.Operation)
124+
if err != nil {
125+
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
126+
resultChannel <- result
127+
return
128+
}
129+
130+
fields := collectFields(collectFieldsParams{
131+
ExeContext: exeContext,
132+
RuntimeType: operationType,
133+
SelectionSet: exeContext.Operation.GetSelectionSet(),
134+
})
135+
136+
responseNames := []string{}
137+
for name := range fields {
138+
responseNames = append(responseNames, name)
139+
}
140+
responseName := responseNames[0]
141+
fieldNodes := fields[responseName]
142+
fieldNode := fieldNodes[0]
143+
fieldName := fieldNode.Name.Value
144+
fieldDef := getFieldDef(p.Schema, operationType, fieldName)
145+
146+
if fieldDef == nil {
147+
err := fmt.Errorf("the subscription field %q is not defined", fieldName)
148+
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
149+
resultChannel <- result
150+
return
151+
}
152+
153+
resolveFn := p.FieldSubscriber
154+
if resolveFn == nil {
155+
resolveFn = DefaultResolveFn
156+
}
157+
if fieldDef.Subscribe != nil {
158+
resolveFn = fieldDef.Subscribe
159+
}
160+
fieldPath := &ResponsePath{
161+
Key: responseName,
162+
}
163+
164+
args := getArgumentValues(fieldDef.Args, fieldNode.Arguments, exeContext.VariableValues)
165+
info := ResolveInfo{
166+
FieldName: fieldName,
167+
FieldASTs: fieldNodes,
168+
Path: fieldPath,
169+
ReturnType: fieldDef.Type,
170+
ParentType: operationType,
171+
Schema: p.Schema,
172+
Fragments: exeContext.Fragments,
173+
RootValue: exeContext.Root,
174+
Operation: exeContext.Operation,
175+
VariableValues: exeContext.VariableValues,
176+
}
177+
178+
fieldResult, err := resolveFn(ResolveParams{
179+
Source: p.RootValue,
180+
Args: args,
181+
Info: info,
182+
Context: exeContext.Context,
183+
})
184+
if err != nil {
185+
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
186+
resultChannel <- result
187+
return
188+
}
189+
190+
if fieldResult == nil {
191+
err := fmt.Errorf("no field result")
192+
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
193+
resultChannel <- result
194+
return
195+
}
196+
197+
switch fieldResult.(type) {
198+
case chan interface{}:
199+
for {
200+
select {
201+
case <-ctx.Done():
202+
fmt.Printf("done context called")
203+
return
204+
case res := <-fieldResult.(chan interface{}):
205+
206+
resultChannel <- mapSourceToResponse(res)
207+
}
208+
}
209+
default:
210+
resultChannel <- mapSourceToResponse(fieldResult)
211+
return
212+
}
213+
}()
214+
215+
// return a result iterator
216+
return NewResultIterator(p.ContextValue, resultChannel)
217+
}

subscription_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package graphql
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
"time"
8+
9+
"github.com/graphql-go/graphql/language/parser"
10+
"github.com/graphql-go/graphql/language/source"
11+
)
12+
13+
func TestSubscription(t *testing.T) {
14+
var maxPublish = 5
15+
m := make(chan interface{})
16+
17+
source1 := source.NewSource(&source.Source{
18+
Body: []byte(`subscription {
19+
watch_count
20+
}`),
21+
Name: "GraphQL request",
22+
})
23+
24+
source2 := source.NewSource(&source.Source{
25+
Body: []byte(`subscription {
26+
watch_should_fail
27+
}`),
28+
Name: "GraphQL request",
29+
})
30+
31+
document1, _ := parser.Parse(parser.ParseParams{Source: source1})
32+
document2, _ := parser.Parse(parser.ParseParams{Source: source2})
33+
34+
schema, err := NewSchema(SchemaConfig{
35+
Query: NewObject(ObjectConfig{
36+
Name: "Query",
37+
Fields: Fields{
38+
"hello": &Field{
39+
Type: String,
40+
Resolve: func(p ResolveParams) (interface{}, error) {
41+
return "world", nil
42+
},
43+
},
44+
},
45+
}),
46+
Subscription: NewObject(ObjectConfig{
47+
Name: "Subscription",
48+
Fields: Fields{
49+
"watch_count": &Field{
50+
Type: String,
51+
Resolve: func(p ResolveParams) (interface{}, error) {
52+
return fmt.Sprintf("count=%v", p.Source), nil
53+
},
54+
Subscribe: func(p ResolveParams) (interface{}, error) {
55+
return m, nil
56+
},
57+
},
58+
"watch_should_fail": &Field{
59+
Type: String,
60+
Resolve: func(p ResolveParams) (interface{}, error) {
61+
return fmt.Sprintf("count=%v", p.Source), nil
62+
},
63+
Subscribe: func(p ResolveParams) (interface{}, error) {
64+
return nil, nil
65+
},
66+
},
67+
},
68+
}),
69+
})
70+
71+
if err != nil {
72+
t.Errorf("failed to create schema: %v", err)
73+
return
74+
}
75+
76+
failIterator := Subscribe(SubscribeParams{
77+
Schema: schema,
78+
Document: document2,
79+
})
80+
81+
// test a subscribe that should fail due to no return value
82+
failIterator.ForEach(func(count int64, res *Result, doneFunc func()) {
83+
if !res.HasErrors() {
84+
t.Errorf("subscribe failed to catch nil result from subscribe")
85+
doneFunc()
86+
return
87+
}
88+
doneFunc()
89+
return
90+
})
91+
92+
resultIterator := Subscribe(SubscribeParams{
93+
Schema: schema,
94+
Document: document1,
95+
ContextValue: context.Background(),
96+
})
97+
98+
resultIterator.ForEach(func(count int64, res *Result, doneFunc func()) {
99+
if res.HasErrors() {
100+
t.Errorf("subscribe error(s): %v", res.Errors)
101+
doneFunc()
102+
return
103+
}
104+
105+
if res.Data != nil {
106+
data := res.Data.(map[string]interface{})["watch_count"]
107+
expected := fmt.Sprintf("count=%d", count)
108+
actual := fmt.Sprintf("%v", data)
109+
if actual != expected {
110+
t.Errorf("subscription result error: expected %q, actual %q", expected, actual)
111+
doneFunc()
112+
return
113+
}
114+
115+
// test the done func by quitting after 3 iterations
116+
// the publisher will publish up to 5
117+
if count >= int64(maxPublish-2) {
118+
doneFunc()
119+
return
120+
}
121+
}
122+
})
123+
124+
// start publishing
125+
go func() {
126+
for i := 1; i <= maxPublish; i++ {
127+
time.Sleep(200 * time.Millisecond)
128+
m <- i
129+
}
130+
}()
131+
132+
// give time for the test to complete
133+
time.Sleep(1 * time.Second)
134+
}

0 commit comments

Comments
 (0)