Skip to content

Commit 4a374e3

Browse files
ensure subscribe and resolve use the same cancellable context
1 parent 9cf0da7 commit 4a374e3

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

executor.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"reflect"
8+
"sort"
89
"strings"
910

1011
"github.com/graphql-go/graphql/gqlerrors"
@@ -254,7 +255,9 @@ func executeFieldsSerially(p executeFieldsParams) *Result {
254255
}
255256

256257
finalResults := make(map[string]interface{}, len(p.Fields))
257-
for responseName, fieldASTs := range p.Fields {
258+
for _, orderedField := range orderedFields(p.Fields) {
259+
responseName := orderedField.responseName
260+
fieldASTs := orderedField.fieldASTs
258261
fieldPath := p.Path.WithKey(responseName)
259262
resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs, fieldPath)
260263
if state.hasNoFieldDefs {
@@ -1038,3 +1041,39 @@ func getFieldDef(schema Schema, parentType *Object, fieldName string) *FieldDefi
10381041
}
10391042
return parentType.Fields()[fieldName]
10401043
}
1044+
1045+
// contains field information that will be placed in an ordered slice
1046+
type orderedField struct {
1047+
responseName string
1048+
fieldASTs []*ast.Field
1049+
}
1050+
1051+
// orders fields from a fields map by location in the source
1052+
func orderedFields(fields map[string][]*ast.Field) []*orderedField {
1053+
orderedFields := []*orderedField{}
1054+
fieldMap := map[int]*orderedField{}
1055+
startLocs := []int{}
1056+
1057+
for responseName, fieldASTs := range fields {
1058+
// find the lowest location in the current fieldASTs
1059+
lowest := -1
1060+
for _, fieldAST := range fieldASTs {
1061+
loc := fieldAST.GetLoc().Start
1062+
if lowest == -1 || loc < lowest {
1063+
lowest = loc
1064+
}
1065+
}
1066+
startLocs = append(startLocs, lowest)
1067+
fieldMap[lowest] = &orderedField{
1068+
responseName: responseName,
1069+
fieldASTs: fieldASTs,
1070+
}
1071+
}
1072+
1073+
sort.Ints(startLocs)
1074+
for _, startLoc := range startLocs {
1075+
orderedFields = append(orderedFields, fieldMap[startLoc])
1076+
}
1077+
1078+
return orderedFields
1079+
}

subscription.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,13 @@ type ResultIterator struct {
3939
}
4040

4141
// NewResultIterator creates a new iterator and starts handling message on the result channel
42-
func NewResultIterator(ctx context.Context, ch chan *Result) *ResultIterator {
43-
if ctx == nil {
44-
ctx = context.Background()
45-
}
46-
47-
cctx, cancelFunc := context.WithCancel(ctx)
42+
func NewResultIterator(ctx context.Context, cancelFunc context.CancelFunc, ch chan *Result) *ResultIterator {
4843
iterator := &ResultIterator{
4944
currentHandlerID: 0,
5045
count: 0,
51-
ctx: cctx,
52-
ch: ch,
46+
ctx: ctx,
5347
cancelFunc: cancelFunc,
48+
ch: ch,
5449
cancelled: false,
5550
handlers: map[int64]*subscriptionHanlderConfig{},
5651
}
@@ -140,14 +135,16 @@ func Subscribe(p SubscribeParams) *ResultIterator {
140135
ctx = context.Background()
141136
}
142137

138+
sctx, cancelFunc := context.WithCancel(ctx)
139+
143140
var mapSourceToResponse = func(payload interface{}) *Result {
144141
return Execute(ExecuteParams{
145142
Schema: p.Schema,
146143
Root: payload,
147144
AST: p.Document,
148145
OperationName: p.OperationName,
149146
Args: p.VariableValues,
150-
Context: p.ContextValue,
147+
Context: sctx,
151148
})
152149
}
153150

@@ -168,7 +165,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
168165
OperationName: p.OperationName,
169166
Args: p.VariableValues,
170167
Result: result,
171-
Context: p.ContextValue,
168+
Context: sctx,
172169
})
173170

174171
if err != nil {
@@ -236,7 +233,7 @@ func Subscribe(p SubscribeParams) *ResultIterator {
236233
Source: p.RootValue,
237234
Args: args,
238235
Info: info,
239-
Context: exeContext.Context,
236+
Context: sctx,
240237
})
241238
if err != nil {
242239
result.Errors = append(result.Errors, gqlerrors.FormatError(err.(error)))
@@ -255,11 +252,9 @@ func Subscribe(p SubscribeParams) *ResultIterator {
255252
case chan interface{}:
256253
for {
257254
select {
258-
case <-ctx.Done():
259-
fmt.Printf("done context called")
255+
case <-sctx.Done():
260256
return
261257
case res := <-fieldResult.(chan interface{}):
262-
263258
resultChannel <- mapSourceToResponse(res)
264259
}
265260
}
@@ -270,5 +265,5 @@ func Subscribe(p SubscribeParams) *ResultIterator {
270265
}()
271266

272267
// return a result iterator
273-
return NewResultIterator(p.ContextValue, resultChannel)
268+
return NewResultIterator(sctx, cancelFunc, resultChannel)
274269
}

0 commit comments

Comments
 (0)