Skip to content

Commit c991585

Browse files
committed
rewritten function signature and more tests
1 parent a55996a commit c991585

File tree

3 files changed

+552
-160
lines changed

3 files changed

+552
-160
lines changed

subscription.go

Lines changed: 127 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,66 +5,143 @@ import (
55
"fmt"
66

77
"github.com/graphql-go/graphql/gqlerrors"
8-
"github.com/graphql-go/graphql/language/ast"
8+
"github.com/graphql-go/graphql/language/parser"
9+
"github.com/graphql-go/graphql/language/source"
910
)
1011

1112
// SubscribeParams parameters for subscribing
1213
type SubscribeParams struct {
13-
Schema Schema
14-
Document *ast.Document
15-
RootValue interface{}
16-
ContextValue context.Context
14+
Schema Schema
15+
RequestString string
16+
RootValue interface{}
17+
// ContextValue context.Context
1718
VariableValues map[string]interface{}
1819
OperationName string
1920
FieldResolver FieldResolveFn
2021
FieldSubscriber FieldResolveFn
2122
}
2223

24+
// SubscriptableSchema implements `graphql-transport-ws` `GraphQLService` interface: https://github.com/graph-gophers/graphql-transport-ws/blob/40c0484322990a129cac2f2d2763c3315230280c/graphqlws/internal/connection/connection.go#L53
25+
type SubscriptableSchema struct {
26+
Schema Schema
27+
RootObject map[string]interface{}
28+
}
29+
30+
func (self *SubscriptableSchema) Subscribe(ctx context.Context, queryString string, operationName string, variables map[string]interface{}) (<-chan *Result, error) {
31+
c := Subscribe(Params{
32+
Schema: self.Schema,
33+
Context: ctx,
34+
OperationName: operationName,
35+
RequestString: queryString,
36+
RootObject: self.RootObject,
37+
VariableValues: variables,
38+
})
39+
return c, nil
40+
}
41+
2342
// Subscribe performs a subscribe operation
24-
func Subscribe(ctx context.Context, p SubscribeParams) chan *Result {
43+
func Subscribe(p Params) chan *Result {
44+
45+
source := source.NewSource(&source.Source{
46+
Body: []byte(p.RequestString),
47+
Name: "GraphQL request",
48+
})
49+
50+
// TODO run extensions hooks
51+
52+
// parse the source
53+
AST, err := parser.Parse(parser.ParseParams{Source: source})
54+
if err != nil {
55+
56+
// merge the errors from extensions and the original error from parser
57+
return sendOneResultandClose(&Result{
58+
Errors: gqlerrors.FormatErrors(err),
59+
})
60+
}
61+
62+
// validate document
63+
validationResult := ValidateDocument(&p.Schema, AST, nil)
64+
65+
if !validationResult.IsValid {
66+
// run validation finish functions for extensions
67+
return sendOneResultandClose(&Result{
68+
Errors: validationResult.Errors,
69+
})
70+
71+
}
72+
return ExecuteSubscription(ExecuteParams{
73+
Schema: p.Schema,
74+
Root: p.RootObject,
75+
AST: AST,
76+
OperationName: p.OperationName,
77+
Args: p.VariableValues,
78+
Context: p.Context,
79+
})
80+
}
81+
82+
func sendOneResultandClose(res *Result) chan *Result {
2583
resultChannel := make(chan *Result)
84+
resultChannel <- res
85+
close(resultChannel)
86+
return resultChannel
87+
}
88+
89+
func ExecuteSubscription(p ExecuteParams) chan *Result {
90+
91+
if p.Context == nil {
92+
p.Context = context.Background()
93+
}
94+
95+
// TODO run executionDidStart functions from extensions
2696

2797
var mapSourceToResponse = func(payload interface{}) *Result {
2898
return Execute(ExecuteParams{
2999
Schema: p.Schema,
30100
Root: payload,
31-
AST: p.Document,
101+
AST: p.AST,
32102
OperationName: p.OperationName,
33-
Args: p.VariableValues,
34-
Context: p.ContextValue,
103+
Args: p.Args,
104+
Context: p.Context,
35105
})
36106
}
37-
107+
var resultChannel = make(chan *Result)
38108
go func() {
39-
result := &Result{}
40109
defer func() {
41110
if err := recover(); err != nil {
42-
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
43-
resultChannel <- result
111+
e, ok := err.(error)
112+
if !ok {
113+
return
114+
}
115+
sendOneResultandClose(&Result{
116+
Errors: gqlerrors.FormatErrors(e),
117+
})
44118
}
45-
close(resultChannel)
119+
// close(resultChannel)
120+
return
46121
}()
47122

48123
exeContext, err := buildExecutionContext(buildExecutionCtxParams{
49124
Schema: p.Schema,
50-
Root: p.RootValue,
51-
AST: p.Document,
125+
Root: p.Root,
126+
AST: p.AST,
52127
OperationName: p.OperationName,
53-
Args: p.VariableValues,
54-
Result: result,
55-
Context: p.ContextValue,
128+
Args: p.Args,
129+
Result: &Result{}, // TODO what is this?
130+
Context: p.Context,
56131
})
57132

58133
if err != nil {
59-
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
60-
resultChannel <- result
134+
sendOneResultandClose(&Result{
135+
Errors: gqlerrors.FormatErrors(err),
136+
})
61137
return
62138
}
63139

64140
operationType, err := getOperationRootType(p.Schema, exeContext.Operation)
65141
if err != nil {
66-
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
67-
resultChannel <- result
142+
sendOneResultandClose(&Result{
143+
Errors: gqlerrors.FormatErrors(err),
144+
})
68145
return
69146
}
70147

@@ -85,18 +162,19 @@ func Subscribe(ctx context.Context, p SubscribeParams) chan *Result {
85162
fieldDef := getFieldDef(p.Schema, operationType, fieldName)
86163

87164
if fieldDef == nil {
88-
err := fmt.Errorf("the subscription field %q is not defined", fieldName)
89-
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
90-
resultChannel <- result
165+
sendOneResultandClose(&Result{
166+
Errors: gqlerrors.FormatErrors(fmt.Errorf("the subscription field %q is not defined", fieldName)),
167+
})
91168
return
92169
}
93170

94-
resolveFn := p.FieldSubscriber
171+
resolveFn := fieldDef.Subscribe
172+
95173
if resolveFn == nil {
96-
resolveFn = DefaultResolveFn
97-
}
98-
if fieldDef.Subscribe != nil {
99-
resolveFn = fieldDef.Subscribe
174+
sendOneResultandClose(&Result{
175+
Errors: gqlerrors.FormatErrors(fmt.Errorf("the subscription function %q is not defined", fieldName)),
176+
})
177+
return
100178
}
101179
fieldPath := &ResponsePath{
102180
Key: responseName,
@@ -117,38 +195,47 @@ func Subscribe(ctx context.Context, p SubscribeParams) chan *Result {
117195
}
118196

119197
fieldResult, err := resolveFn(ResolveParams{
120-
Source: p.RootValue,
198+
Source: p.Root,
121199
Args: args,
122200
Info: info,
123-
Context: p.ContextValue,
201+
Context: p.Context,
124202
})
125203
if err != nil {
126-
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
127-
resultChannel <- result
204+
sendOneResultandClose(&Result{
205+
Errors: gqlerrors.FormatErrors(err),
206+
})
128207
return
129208
}
130209

131210
if fieldResult == nil {
132-
err := fmt.Errorf("no field result")
133-
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
134-
resultChannel <- result
211+
sendOneResultandClose(&Result{
212+
Errors: gqlerrors.FormatErrors(fmt.Errorf("no field result")),
213+
})
135214
return
136215
}
137216

138217
switch fieldResult.(type) {
139218
case chan interface{}:
140219
sub := fieldResult.(chan interface{})
220+
defer close(resultChannel)
141221
for {
142222
select {
143-
case <-ctx.Done():
223+
case <-p.Context.Done():
224+
println("context cancelled")
225+
// TODO send the context error to the resultchannel
144226
return
145227

146-
case res := <-sub:
228+
case res, more := <-sub:
229+
if !more {
230+
return
231+
}
147232
resultChannel <- mapSourceToResponse(res)
148233
}
149234
}
150235
default:
236+
fmt.Println(fieldResult)
151237
resultChannel <- mapSourceToResponse(fieldResult)
238+
close(resultChannel)
152239
return
153240
}
154241
}()

0 commit comments

Comments
 (0)